Formatting (#2750)
.github/workflows/ci.yml (vendored): 1 change
@@ -56,6 +56,7 @@ jobs:
 run: |
 make install_all
 pip install -e ".[test]"
+pip install pinecone pinecone-text
 if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
 - name: Run Formatting
 run: |
@@ -13,7 +13,7 @@
 "import anthropic\n",
 "\n",
 "# Set up environment variables\n",
 "os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n",
 "os.environ[\"ANTHROPIC_API_KEY\"] = \"your_anthropic_api_key\""
 ]
 },
@@ -33,7 +33,7 @@
 " \"model\": \"claude-3-5-sonnet-latest\",\n",
 " \"temperature\": 0.1,\n",
 " \"max_tokens\": 2000,\n",
-" }\n",
+" },\n",
 " }\n",
 " }\n",
 " self.client = anthropic.Client(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n",
@@ -50,11 +50,7 @@
 " - Keep track of open issues and follow-ups\n",
 " \"\"\"\n",
 "\n",
-" def store_customer_interaction(self,\n",
-" user_id: str,\n",
-" message: str,\n",
-" response: str,\n",
-" metadata: Dict = None):\n",
+" def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):\n",
 " \"\"\"Store customer interaction in memory.\"\"\"\n",
 " if metadata is None:\n",
 " metadata = {}\n",
@@ -63,24 +59,17 @@
 " metadata[\"timestamp\"] = datetime.now().isoformat()\n",
 "\n",
 " # Format conversation for storage\n",
-" conversation = [\n",
-" {\"role\": \"user\", \"content\": message},\n",
-" {\"role\": \"assistant\", \"content\": response}\n",
-" ]\n",
+" conversation = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response}]\n",
 "\n",
 " # Store in Mem0\n",
-" self.memory.add(\n",
-" conversation,\n",
-" user_id=user_id,\n",
-" metadata=metadata\n",
-" )\n",
+" self.memory.add(conversation, user_id=user_id, metadata=metadata)\n",
 "\n",
 " def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:\n",
 " \"\"\"Retrieve relevant past interactions.\"\"\"\n",
 " return self.memory.search(\n",
 " query=query,\n",
 " user_id=user_id,\n",
-" limit=5 # Adjust based on needs\n",
+" limit=5, # Adjust based on needs\n",
 " )\n",
 "\n",
 " def handle_customer_query(self, user_id: str, query: str) -> str:\n",
@@ -112,15 +101,12 @@
 " model=\"claude-3-5-sonnet-latest\",\n",
 " messages=[{\"role\": \"user\", \"content\": prompt}],\n",
 " max_tokens=2000,\n",
-" temperature=0.1\n",
+" temperature=0.1,\n",
 " )\n",
 "\n",
 " # Store interaction\n",
 " self.store_customer_interaction(\n",
-" user_id=user_id,\n",
-" message=query,\n",
-" response=response,\n",
-" metadata={\"type\": \"support_query\"}\n",
+" user_id=user_id, message=query, response=response, metadata={\"type\": \"support_query\"}\n",
 " )\n",
 "\n",
 " return response.content[0].text"
@@ -203,12 +189,12 @@
 " # Get user input\n",
 " query = input()\n",
 " print(\"Customer:\", query)\n",
-" \n",
+"\n",
 " # Check if user wants to exit\n",
-" if query.lower() == 'exit':\n",
+" if query.lower() == \"exit\":\n",
 " print(\"Thank you for using our support service. Goodbye!\")\n",
 " break\n",
-" \n",
+"\n",
 " # Handle the query and print the response\n",
 " response = chatbot.handle_customer_query(user_id, query)\n",
 " print(\"Support:\", response, \"\\n\\n\")"
@@ -25,7 +25,8 @@
 "source": [
 "# Set up ENV Vars\n",
 "import os\n",
-"os.environ['OPENAI_API_KEY'] = \"sk-xxx\"\n"
+"\n",
+"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
 ]
 },
 {
@@ -133,11 +134,9 @@
 "assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n",
 "\n",
 "# LLM Configuration\n",
 "CACHE_SEED = 42 # choose your poison\n",
 "llm_config = {\n",
-" \"config_list\": [\n",
-" {\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}\n",
-" ],\n",
+" \"config_list\": [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}],\n",
 " \"cache_seed\": CACHE_SEED,\n",
 " \"timeout\": 120,\n",
 " \"temperature\": 0.0,\n",
@@ -348,7 +347,7 @@
 "source": [
 "# Retrieve the memory\n",
 "relevant_memories = MEM0_MEMORY_CLIENT.search(user_query, user_id=USER_ID, limit=3)\n",
-"relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n",
+"relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n",
 "print(\"Relevant memories:\")\n",
 "print(relevant_memories_text)\n",
 "\n",
@@ -389,8 +388,8 @@
 "# - Enables more context-aware and personalized agent responses.\n",
 "# - Bridges the gap between human input and AI processing in complex workflows.\n",
 "\n",
-"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
 "\n",
+"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
 " def __init__(self, *args, **kwargs):\n",
 " super().__init__(*args, **kwargs)\n",
 " self.memory = MEM0_MEMORY_CLIENT\n",
@@ -399,15 +398,14 @@
 " def initiate_chat(self, assistant, message):\n",
 " # Retrieve memory for the agent\n",
 " agent_memories = self.memory.search(message, agent_id=self.agent_id, limit=3)\n",
-" agent_memories_txt = '\\n'.join(mem['memory'] for mem in agent_memories)\n",
+" agent_memories_txt = \"\\n\".join(mem[\"memory\"] for mem in agent_memories)\n",
 " prompt = f\"{message}\\n Coding Preferences: \\n{str(agent_memories_txt)}\"\n",
 " response = super().initiate_chat(assistant, message=prompt)\n",
 " # Add new memory after processing the message\n",
 " response_dist = response.__dict__ if not isinstance(response, dict) else response\n",
 " MEMORY_DATA = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response_dist}]\n",
 " self.memory.add(MEMORY_DATA, agent_id=self.agent_id)\n",
-" return response\n",
-" "
+" return response"
 ]
 },
 {
@@ -560,12 +558,12 @@
 "from cookbooks.helper.mem0_teachability import Mem0Teachability\n",
 "\n",
 "teachability = Mem0Teachability(\n",
 " verbosity=2, # for visibility of what's happening\n",
 " recall_threshold=0.5,\n",
 " reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n",
 " agent_id=AGENT_ID,\n",
-" memory_client = MEM0_MEMORY_CLIENT,\n",
-" )\n",
+" memory_client=MEM0_MEMORY_CLIENT,\n",
+")\n",
 "teachability.add_to_agent(user_proxy)"
 ]
 },
@@ -14,46 +14,47 @@ def process_item(item_data):
 local_results = defaultdict(list)

 for item in v:
-gt_answer = str(item['answer'])
-pred_answer = str(item['response'])
-category = str(item['category'])
-question = str(item['question'])
+gt_answer = str(item["answer"])
+pred_answer = str(item["response"])
+category = str(item["category"])
+question = str(item["question"])

 # Skip category 5
-if category == '5':
+if category == "5":
 continue

 metrics = calculate_metrics(pred_answer, gt_answer)
 bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
 llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)

-local_results[k].append({
-"question": question,
-"answer": gt_answer,
-"response": pred_answer,
-"category": category,
-"bleu_score": bleu_scores["bleu1"],
-"f1_score": metrics["f1"],
-"llm_score": llm_score
-})
+local_results[k].append(
+{
+"question": question,
+"answer": gt_answer,
+"response": pred_answer,
+"category": category,
+"bleu_score": bleu_scores["bleu1"],
+"f1_score": metrics["f1"],
+"llm_score": llm_score,
+}
+)

 return local_results


 def main():
-parser = argparse.ArgumentParser(description='Evaluate RAG results')
-parser.add_argument('--input_file', type=str,
-default="results/rag_results_500_k1.json",
-help='Path to the input dataset file')
-parser.add_argument('--output_file', type=str,
-default="evaluation_metrics.json",
-help='Path to save the evaluation results')
-parser.add_argument('--max_workers', type=int, default=10,
-help='Maximum number of worker threads')
+parser = argparse.ArgumentParser(description="Evaluate RAG results")
+parser.add_argument(
+"--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file"
+)
+parser.add_argument(
+"--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results"
+)
+parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads")

 args = parser.parse_args()

-with open(args.input_file, 'r') as f:
+with open(args.input_file, "r") as f:
 data = json.load(f)

 results = defaultdict(list)
@@ -61,18 +62,16 @@ def main():

 # Use ThreadPoolExecutor with specified workers
 with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-futures = [executor.submit(process_item, item_data)
-for item_data in data.items()]
+futures = [executor.submit(process_item, item_data) for item_data in data.items()]

-for future in tqdm(concurrent.futures.as_completed(futures),
-total=len(futures)):
+for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
 local_results = future.result()
 with results_lock:
 for k, items in local_results.items():
 results[k].extend(items)

 # Save results to JSON file
-with open(args.output_file, 'w') as f:
+with open(args.output_file, "w") as f:
 json.dump(results, f, indent=4)

 print(f"Results saved to {args.output_file}")
@@ -3,7 +3,7 @@ import json
 import pandas as pd

 # Load the evaluation metrics data
-with open('evaluation_metrics.json', 'r') as f:
+with open("evaluation_metrics.json", "r") as f:
 data = json.load(f)

 # Flatten the data into a list of question items
@@ -15,28 +15,20 @@ for key in data:
 df = pd.DataFrame(all_items)

 # Convert category to numeric type
-df['category'] = pd.to_numeric(df['category'])
+df["category"] = pd.to_numeric(df["category"])

 # Calculate mean scores by category
-result = df.groupby('category').agg({
-'bleu_score': 'mean',
-'f1_score': 'mean',
-'llm_score': 'mean'
-}).round(4)
+result = df.groupby("category").agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)

 # Add count of questions per category
-result['count'] = df.groupby('category').size()
+result["count"] = df.groupby("category").size()

 # Print the results
 print("Mean Scores Per Category:")
 print(result)

 # Calculate overall means
-overall_means = df.agg({
-'bleu_score': 'mean',
-'f1_score': 'mean',
-'llm_score': 'mean'
-}).round(4)
+overall_means = df.agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)

 print("\nOverall Mean Scores:")
 print(overall_means)
@@ -33,35 +33,34 @@ Do NOT include both CORRECT and WRONG in your response, or it will break the eva
 Just return the label CORRECT or WRONG in a json format with the key as "label".
 """


 def evaluate_llm_judge(question, gold_answer, generated_answer):
 """Evaluate the generated answer against the gold answer using an LLM judge."""
 response = client.chat.completions.create(
 model="gpt-4o-mini",
-messages=[{
-"role": "user",
-"content": ACCURACY_PROMPT.format(
-question=question,
-gold_answer=gold_answer,
-generated_answer=generated_answer
-)
-}],
+messages=[
+{
+"role": "user",
+"content": ACCURACY_PROMPT.format(
+question=question, gold_answer=gold_answer, generated_answer=generated_answer
+),
+}
+],
 response_format={"type": "json_object"},
-temperature=0.0
+temperature=0.0,
 )
-label = json.loads(response.choices[0].message.content)['label']
+label = json.loads(response.choices[0].message.content)["label"]
 return 1 if label == "CORRECT" else 0


 def main():
 """Main function to evaluate RAG results using LLM judge."""
-parser = argparse.ArgumentParser(
-description='Evaluate RAG results using LLM judge'
-)
+parser = argparse.ArgumentParser(description="Evaluate RAG results using LLM judge")
 parser.add_argument(
-'--input_file',
+"--input_file",
 type=str,
 default="results/default_run_v4_k30_new_graph.json",
-help='Path to the input dataset file'
+help="Path to the input dataset file",
 )

 args = parser.parse_args()
@@ -78,10 +77,10 @@ def main():
 index = 0
 for k, v in data.items():
 for x in v:
-question = x['question']
-gold_answer = x['answer']
-generated_answer = x['response']
-category = x['category']
+question = x["question"]
+gold_answer = x["answer"]
+generated_answer = x["response"]
+category = x["category"]

 # Skip category 5
 if int(category) == 5:
@@ -92,13 +91,15 @@ def main():
 LLM_JUDGE[category].append(label)

 # Store the results
-RESULTS[index].append({
-"question": question,
-"gt_answer": gold_answer,
-"response": generated_answer,
-"category": category,
-"llm_label": label
-})
+RESULTS[index].append(
+{
+"question": question,
+"gt_answer": gold_answer,
+"response": generated_answer,
+"category": category,
+"llm_label": label,
+}
+)

 # Save intermediate results
 with open(output_path, "w") as f:
@@ -108,8 +109,7 @@ def main():
 print("All categories accuracy:")
 for cat, results in LLM_JUDGE.items():
 if results: # Only print if there are results for this category
-print(f" Category {cat}: {np.mean(results):.4f} "
-f"({sum(results)}/{len(results)})")
+print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
 print("------------------------------------------")
 index += 1

@@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py

 @article{xu2025mem,
 title={A-mem: Agentic memory for llm agents},
 author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
 and Zhang, Yongfeng},
 journal={arXiv preprint arXiv:2502.12110},
 year={2025}
@@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim

 # Download required NLTK data
 try:
-nltk.download('punkt', quiet=True)
-nltk.download('wordnet', quiet=True)
+nltk.download("punkt", quiet=True)
+nltk.download("wordnet", quiet=True)
 except Exception as e:
 print(f"Error downloading NLTK data: {e}")

 # Initialize SentenceTransformer model (this will be reused)
 try:
-sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 except Exception as e:
 print(f"Warning: Could not load SentenceTransformer model: {e}")
 sentence_model = None


 def simple_tokenize(text):
 """Simple tokenization function."""
 # Convert to string if not already
 text = str(text)
-return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
+return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split()


 def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate ROUGE scores for prediction against reference."""
-scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
 scores = scorer.score(reference, prediction)
 return {
-'rouge1_f': scores['rouge1'].fmeasure,
-'rouge2_f': scores['rouge2'].fmeasure,
-'rougeL_f': scores['rougeL'].fmeasure
+"rouge1_f": scores["rouge1"].fmeasure,
+"rouge2_f": scores["rouge2"].fmeasure,
+"rougeL_f": scores["rougeL"].fmeasure,
 }


 def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate BLEU scores with different n-gram settings."""
 pred_tokens = nltk.word_tokenize(prediction.lower())
 ref_tokens = [nltk.word_tokenize(reference.lower())]

 weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
 smooth = SmoothingFunction().method1

 scores = {}
 for n, weights in enumerate(weights_list, start=1):
 try:
@@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
 except Exception as e:
 print(f"Error calculating BLEU score: {e}")
 score = 0.0
-scores[f'bleu{n}'] = score
+scores[f"bleu{n}"] = score

 return scores


 def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate BERTScore for semantic similarity."""
 try:
-P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
-return {
-'bert_precision': P.item(),
-'bert_recall': R.item(),
-'bert_f1': F1.item()
-}
+P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
+return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
 except Exception as e:
 print(f"Error calculating BERTScore: {e}")
-return {
-'bert_precision': 0.0,
-'bert_recall': 0.0,
-'bert_f1': 0.0
-}
+return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}

 def calculate_meteor_score(prediction: str, reference: str) -> float:
 """Calculate METEOR score for the prediction."""
@@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float:
 print(f"Error calculating METEOR score: {e}")
 return 0.0

+
 def calculate_sentence_similarity(prediction: str, reference: str) -> float:
 """Calculate sentence embedding similarity using SentenceBERT."""
 if sentence_model is None:
@@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
 # Encode sentences
 embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
 embedding2 = sentence_model.encode([reference], convert_to_tensor=True)

 # Calculate cosine similarity
 similarity = pytorch_cos_sim(embedding1, embedding2).item()
 return float(similarity)
@@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
 print(f"Error calculating sentence similarity: {e}")
 return 0.0

+
 def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate comprehensive evaluation metrics for a prediction."""
 # Handle empty or None values
@@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
 "bleu4": 0.0,
 "bert_f1": 0.0,
 "meteor": 0.0,
-"sbert_similarity": 0.0
+"sbert_similarity": 0.0,
 }

 # Convert to strings if they're not already
 prediction = str(prediction).strip()
 reference = str(reference).strip()

 # Calculate exact match
 exact_match = int(prediction.lower() == reference.lower())

 # Calculate token-based F1 score
 pred_tokens = set(simple_tokenize(prediction))
 ref_tokens = set(simple_tokenize(reference))
 common_tokens = pred_tokens & ref_tokens

 if not pred_tokens or not ref_tokens:
 f1 = 0.0
 else:
 precision = len(common_tokens) / len(pred_tokens)
 recall = len(common_tokens) / len(ref_tokens)
 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

 # Calculate all scores
 bleu_scores = calculate_bleu_scores(prediction, reference)

 # Combine all metrics
 metrics = {
 "exact_match": exact_match,
@@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:

 return metrics

-def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
+def aggregate_metrics(
+all_metrics: List[Dict[str, float]], all_categories: List[int]
+) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
 """Calculate aggregate statistics for all metrics, split by category."""
 if not all_metrics:
 return {}

 # Initialize aggregates for overall and per-category metrics
 aggregates = defaultdict(list)
 category_aggregates = defaultdict(lambda: defaultdict(list))

 # Collect all values for each metric, both overall and per category
 for metrics, category in zip(all_metrics, all_categories):
 for metric_name, value in metrics.items():
 aggregates[metric_name].append(value)
 category_aggregates[category][metric_name].append(value)

 # Calculate statistics for overall metrics
-results = {
-"overall": {}
-}
+results = {"overall": {}}

 for metric_name, values in aggregates.items():
 results["overall"][metric_name] = {
-'mean': statistics.mean(values),
-'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-'median': statistics.median(values),
-'min': min(values),
-'max': max(values),
-'count': len(values)
+"mean": statistics.mean(values),
+"std": statistics.stdev(values) if len(values) > 1 else 0.0,
+"median": statistics.median(values),
+"min": min(values),
+"max": max(values),
+"count": len(values),
 }

 # Calculate statistics for each category
 for category in sorted(category_aggregates.keys()):
 results[f"category_{category}"] = {}
 for metric_name, values in category_aggregates[category].items():
 if values: # Only calculate if we have values for this category
 results[f"category_{category}"][metric_name] = {
-'mean': statistics.mean(values),
-'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-'median': statistics.median(values),
-'min': min(values),
-'max': max(values),
-'count': len(values)
+"mean": statistics.mean(values),
+"std": statistics.stdev(values) if len(values) > 1 else 0.0,
+"median": statistics.median(values),
+"min": min(values),
+"max": max(values),
+"count": len(values),
 }

 return results
@@ -144,4 +144,4 @@ ANSWER_PROMPT_ZEP = """

 Question: {{question}}
 Answer:
 """
@@ -21,23 +21,15 @@ class Experiment:


 def main():
-parser = argparse.ArgumentParser(description='Run memory experiments')
-parser.add_argument('--technique_type', choices=TECHNIQUES, default='mem0',
-help='Memory technique to use')
-parser.add_argument('--method', choices=METHODS, default='add',
-help='Method to use')
-parser.add_argument('--chunk_size', type=int, default=1000,
-help='Chunk size for processing')
-parser.add_argument('--output_folder', type=str, default='results/',
-help='Output path for results')
-parser.add_argument('--top_k', type=int, default=30,
-help='Number of top memories to retrieve')
-parser.add_argument('--filter_memories', action='store_true', default=False,
-help='Whether to filter memories')
-parser.add_argument('--is_graph', action='store_true', default=False,
-help='Whether to use graph-based search')
-parser.add_argument('--num_chunks', type=int, default=1,
-help='Number of chunks to process')
+parser = argparse.ArgumentParser(description="Run memory experiments")
+parser.add_argument("--technique_type", choices=TECHNIQUES, default="mem0", help="Memory technique to use")
+parser.add_argument("--method", choices=METHODS, default="add", help="Method to use")
+parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for processing")
+parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
+parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
+parser.add_argument("--filter_memories", action="store_true", default=False, help="Whether to filter memories")
+parser.add_argument("--is_graph", action="store_true", default=False, help="Whether to use graph-based search")
+parser.add_argument("--num_chunks", type=int, default=1, help="Number of chunks to process")

 args = parser.parse_args()

@@ -46,33 +38,18 @@ def main():

 if args.technique_type == "mem0":
 if args.method == "add":
-memory_manager = MemoryADD(
-data_path='dataset/locomo10.json',
-is_graph=args.is_graph
-)
+memory_manager = MemoryADD(data_path="dataset/locomo10.json", is_graph=args.is_graph)
 memory_manager.process_all_conversations()
 elif args.method == "search":
 output_file_path = os.path.join(
 args.output_folder,
-f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json"
+f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json",
 )
-memory_searcher = MemorySearch(
-output_file_path,
-args.top_k,
-args.filter_memories,
-args.is_graph
-)
-memory_searcher.process_data_file('dataset/locomo10.json')
+memory_searcher = MemorySearch(output_file_path, args.top_k, args.filter_memories, args.is_graph)
+memory_searcher.process_data_file("dataset/locomo10.json")
 elif args.technique_type == "rag":
-output_file_path = os.path.join(
-args.output_folder,
-f"rag_results_{args.chunk_size}_k{args.num_chunks}.json"
-)
-rag_manager = RAGManager(
-data_path="dataset/locomo10_rag.json",
-chunk_size=args.chunk_size,
-k=args.num_chunks
-)
+output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
+rag_manager = RAGManager(data_path="dataset/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
 rag_manager.process_all_conversations(output_file_path)
 elif args.technique_type == "langmem":
 output_file_path = os.path.join(args.output_folder, "langmem_results.json")
@@ -85,11 +62,7 @@ def main():
 elif args.method == "search":
 output_file_path = os.path.join(args.output_folder, "zep_search_results.json")
 zep_manager = ZepSearch()
-zep_manager.process_data_file(
-"dataset/locomo10.json",
-"1",
-output_file_path
-)
+zep_manager.process_data_file("dataset/locomo10.json", "1", output_file_path)
 elif args.technique_type == "openai":
 output_file_path = os.path.join(args.output_folder, "openai_results.json")
 openai_manager = OpenAIPredict()
@@ -28,14 +28,12 @@ def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_i
 speaker_1_user_id=speaker_1_user_id,
 speaker_1_memories=speaker_1_memories,
 speaker_2_user_id=speaker_2_user_id,
-speaker_2_memories=speaker_2_memories
+speaker_2_memories=speaker_2_memories,
 )

 t1 = time.time()
 response = client.chat.completions.create(
-model=os.getenv("MODEL"),
-messages=[{"role": "system", "content": prompt}],
-temperature=0.0
+model=os.getenv("MODEL"), messages=[{"role": "system", "content": prompt}], temperature=0.0
 )
 t2 = time.time()
 return response.choices[0].message.content, t2 - t1
@@ -59,7 +57,9 @@ def prompt(state):


 class LangMem:
-def __init__(self,):
+def __init__(
+self,
+):
 self.store = InMemoryStore(
 index={
 "dims": 1536,
@@ -80,18 +80,12 @@ class LangMem:
 )

 def add_memory(self, message, config):
-return self.agent.invoke(
-{"messages": [{"role": "user", "content": message}]},
-config=config
-)
+return self.agent.invoke({"messages": [{"role": "user", "content": message}]}, config=config)

 def search_memory(self, query, config):
 try:
 t1 = time.time()
-response = self.agent.invoke(
-{"messages": [{"role": "user", "content": query}]},
-config=config
-)
+response = self.agent.invoke({"messages": [{"role": "user", "content": query}]}, config=config)
 t2 = time.time()
 return response["messages"][-1].content, t2 - t1
 except Exception as e:
@@ -102,7 +96,7 @@ class LangMem:
 class LangMemManager:
 def __init__(self, dataset_path):
 self.dataset_path = dataset_path
-with open(self.dataset_path, 'r') as f:
+with open(self.dataset_path, "r") as f:
 self.data = json.load(f)

 def process_all_conversations(self, output_file_path):
@@ -123,7 +117,7 @@ class LangMemManager:

 # Identify speakers
 for conv in chat_history:
-speakers.add(conv['speaker'])
+speakers.add(conv["speaker"])

 if len(speakers) != 2:
 raise ValueError(f"Expected 2 speakers, got {len(speakers)}")
@@ -134,50 +128,52 @@ class LangMemManager:
 # Add memories for each message
 for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
 message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
-if conv['speaker'] == speaker1:
+if conv["speaker"] == speaker1:
 agent1.add_memory(message, config)
-elif conv['speaker'] == speaker2:
+elif conv["speaker"] == speaker2:
 agent2.add_memory(message, config)
 else:
 raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}")

 # Process questions
 for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
-category = q['category']
+category = q["category"]

 if int(category) == 5:
 continue

-answer = q['answer']
-question = q['question']
+answer = q["answer"]
+question = q["question"]
 response1, speaker1_memory_time = agent1.search_memory(question, config)
 response2, speaker2_memory_time = agent2.search_memory(question, config)

-generated_answer, response_time = get_answer(
-question, speaker1, response1, speaker2, response2
-)
+generated_answer, response_time = get_answer(question, speaker1, response1, speaker2, response2)

-result[key].append({
-"question": question,
-"answer": answer,
-"response1": response1,
-"response2": response2,
-"category": category,
-"speaker1_memory_time": speaker1_memory_time,
-"speaker2_memory_time": speaker2_memory_time,
-"response_time": response_time,
-'response': generated_answer
-})
+result[key].append(
+{
+"question": question,
+"answer": answer,
+"response1": response1,
+"response2": response2,
+"category": category,
+"speaker1_memory_time": speaker1_memory_time,
+"speaker2_memory_time": speaker2_memory_time,
+"response_time": response_time,
+"response": generated_answer,
+}
+)

 return result

 # Use multiprocessing to process conversations in parallel
 with mp.Pool(processes=10) as pool:
-results = list(tqdm(
-pool.imap(process_conversation, list(self.data.items())),
-total=len(self.data),
-desc="Processing conversations"
-))
+results = list(
+tqdm(
+pool.imap(process_conversation, list(self.data.items())),
+total=len(self.data),
+desc="Processing conversations",
+)
+)

 # Combine results from all workers
 for result in results:
@@ -185,5 +181,5 @@ class LangMemManager:
 OUTPUT[key].extend(items)

 # Save final results
-with open(output_file_path, 'w') as f:
+with open(output_file_path, "w") as f:
 json.dump(OUTPUT, f, indent=4)
@@ -13,7 +13,7 @@ load_dotenv()


 # Update custom instructions
-custom_instructions ="""
+custom_instructions = """
 Generate personal memories that follow these guidelines:

 1. Each memory should be self-contained with complete context, including:
@@ -47,7 +47,7 @@ class MemoryADD:
 self.mem0_client = MemoryClient(
 api_key=os.getenv("MEM0_API_KEY"),
 org_id=os.getenv("MEM0_ORGANIZATION_ID"),
-project_id=os.getenv("MEM0_PROJECT_ID")
+project_id=os.getenv("MEM0_PROJECT_ID"),
 )

 self.mem0_client.update_project(custom_instructions=custom_instructions)
@@ -59,15 +59,16 @@ class MemoryADD:
 self.load_data()

 def load_data(self):
-with open(self.data_path, 'r') as f:
+with open(self.data_path, "r") as f:
 self.data = json.load(f)
 return self.data

 def add_memory(self, user_id, message, metadata, retries=3):
 for attempt in range(retries):
 try:
-_ = self.mem0_client.add(message, user_id=user_id, version="v2",
-metadata=metadata, enable_graph=self.is_graph)
+_ = self.mem0_client.add(
+message, user_id=user_id, version="v2", metadata=metadata, enable_graph=self.is_graph
+)
 return
 except Exception as e:
 if attempt < retries - 1:
@@ -78,13 +79,13 @@ class MemoryADD:

 def add_memories_for_speaker(self, speaker, messages, timestamp, desc):
 for i in tqdm(range(0, len(messages), self.batch_size), desc=desc):
-batch_messages = messages[i:i+self.batch_size]
+batch_messages = messages[i : i + self.batch_size]
 self.add_memory(speaker, batch_messages, metadata={"timestamp": timestamp})

 def process_conversation(self, item, idx):
-conversation = item['conversation']
-speaker_a = conversation['speaker_a']
-speaker_b = conversation['speaker_b']
+conversation = item["conversation"]
+speaker_a = conversation["speaker_a"]
+speaker_b = conversation["speaker_b"]

 speaker_a_user_id = f"{speaker_a}_{idx}"
 speaker_b_user_id = f"{speaker_b}_{idx}"
@@ -94,7 +95,7 @@ class MemoryADD:
 self.mem0_client.delete_all(user_id=speaker_b_user_id)

 for key in conversation.keys():
-if key in ['speaker_a', 'speaker_b'] or "date" in key or "timestamp" in key:
+if key in ["speaker_a", "speaker_b"] or "date" in key or "timestamp" in key:
 continue

 date_time_key = key + "_date_time"
@@ -104,10 +105,10 @@ class MemoryADD:
 messages = []
 messages_reverse = []
 for chat in chats:
-if chat['speaker'] == speaker_a:
+if chat["speaker"] == speaker_a:
 messages.append({"role": "user", "content": f"{speaker_a}: {chat['text']}"})
 messages_reverse.append({"role": "assistant", "content": f"{speaker_a}: {chat['text']}"})
-elif chat['speaker'] == speaker_b:
+elif chat["speaker"] == speaker_b:
 messages.append({"role": "assistant", "content": f"{speaker_b}: {chat['text']}"})
 messages_reverse.append({"role": "user", "content": f"{speaker_b}: {chat['text']}"})
 else:
@@ -116,11 +117,11 @@ class MemoryADD:
 # add memories for the two users on different threads
 thread_a = threading.Thread(
 target=self.add_memories_for_speaker,
-args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A")
+args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A"),
 )
 thread_b = threading.Thread(
 target=self.add_memories_for_speaker,
-args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B")
+args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B"),
 )

 thread_a.start()
@@ -134,10 +135,7 @@ class MemoryADD:
 if not self.data:
 raise ValueError("No data loaded. Please set data_path and call load_data() first.")
 with ThreadPoolExecutor(max_workers=max_workers) as executor:
-futures = [
-executor.submit(self.process_conversation, item, idx)
-for idx, item in enumerate(self.data)
-]
+futures = [executor.submit(self.process_conversation, item, idx) for idx, item in enumerate(self.data)]

 for future in futures:
 future.result()
@@ -16,12 +16,11 @@ load_dotenv()


 class MemorySearch:
-def __init__(self, output_path='results.json', top_k=10, filter_memories=False, is_graph=False):
+def __init__(self, output_path="results.json", top_k=10, filter_memories=False, is_graph=False):
 self.mem0_client = MemoryClient(
 api_key=os.getenv("MEM0_API_KEY"),
 org_id=os.getenv("MEM0_ORGANIZATION_ID"),
-project_id=os.getenv("MEM0_PROJECT_ID")
+project_id=os.getenv("MEM0_PROJECT_ID"),
 )
 self.top_k = top_k
 self.openai_client = OpenAI()
@@ -42,11 +41,18 @@ class MemorySearch:
 try:
 if self.is_graph:
 print("Searching with graph")
-memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
-filter_memories=self.filter_memories, enable_graph=True, output_format='v1.1')
+memories = self.mem0_client.search(
+query,
+user_id=user_id,
+top_k=self.top_k,
+filter_memories=self.filter_memories,
+enable_graph=True,
+output_format="v1.1",
+)
 else:
-memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
-filter_memories=self.filter_memories)
+memories = self.mem0_client.search(
+query, user_id=user_id, top_k=self.top_k, filter_memories=self.filter_memories
+)
 break
 except Exception as e:
 print("Retrying...")
@@ -57,64 +63,86 @@ class MemorySearch:

 end_time = time.time()
 if not self.is_graph:
-semantic_memories = [{'memory': memory['memory'],
-'timestamp': memory['metadata']['timestamp'],
-'score': round(memory['score'], 2)}
-for memory in memories]
+semantic_memories = [
+{
+"memory": memory["memory"],
+"timestamp": memory["metadata"]["timestamp"],
+"score": round(memory["score"], 2),
+}
+for memory in memories
+]
 graph_memories = None
 else:
-semantic_memories = [{'memory': memory['memory'],
-'timestamp': memory['metadata']['timestamp'],
-'score': round(memory['score'], 2)} for memory in memories['results']]
-graph_memories = [{"source": relation['source'], "relationship": relation['relationship'], "target": relation['target']} for relation in memories['relations']]
+semantic_memories = [
+{
+"memory": memory["memory"],
+"timestamp": memory["metadata"]["timestamp"],
+"score": round(memory["score"], 2),
+}
+for memory in memories["results"]
+]
+graph_memories = [
+{"source": relation["source"], "relationship": relation["relationship"], "target": relation["target"]}
+for relation in memories["relations"]
+]
 return semantic_memories, graph_memories, end_time - start_time

 def answer_question(self, speaker_1_user_id, speaker_2_user_id, question, answer, category):
-speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(speaker_1_user_id, question)
-speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(speaker_2_user_id, question)
+speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(
+speaker_1_user_id, question
+)
+speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(
+speaker_2_user_id, question
+)

-search_1_memory = [f"{item['timestamp']}: {item['memory']}"
-for item in speaker_1_memories]
-search_2_memory = [f"{item['timestamp']}: {item['memory']}"
-for item in speaker_2_memories]
+search_1_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_1_memories]
+search_2_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_2_memories]

 template = Template(self.ANSWER_PROMPT)
 answer_prompt = template.render(
-speaker_1_user_id=speaker_1_user_id.split('_')[0],
-speaker_2_user_id=speaker_2_user_id.split('_')[0],
+speaker_1_user_id=speaker_1_user_id.split("_")[0],
+speaker_2_user_id=speaker_2_user_id.split("_")[0],
 speaker_1_memories=json.dumps(search_1_memory, indent=4),
 speaker_2_memories=json.dumps(search_2_memory, indent=4),
 speaker_1_graph_memories=json.dumps(speaker_1_graph_memories, indent=4),
 speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4),
|
speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4),
|
||||||
question=question
|
question=question,
|
||||||
)
|
)
|
||||||
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
response = self.openai_client.chat.completions.create(
|
response = self.openai_client.chat.completions.create(
|
||||||
model=os.getenv("MODEL"),
|
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": answer_prompt}
|
|
||||||
],
|
|
||||||
temperature=0.0
|
|
||||||
)
|
)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
response_time = t2 - t1
|
response_time = t2 - t1
|
||||||
return response.choices[0].message.content, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time
|
return (
|
||||||
|
response.choices[0].message.content,
|
||||||
|
speaker_1_memories,
|
||||||
|
speaker_2_memories,
|
||||||
|
speaker_1_memory_time,
|
||||||
|
speaker_2_memory_time,
|
||||||
|
speaker_1_graph_memories,
|
||||||
|
speaker_2_graph_memories,
|
||||||
|
response_time,
|
||||||
|
)
|
||||||
|
|
||||||
def process_question(self, val, speaker_a_user_id, speaker_b_user_id):
|
def process_question(self, val, speaker_a_user_id, speaker_b_user_id):
|
||||||
question = val.get('question', '')
|
question = val.get("question", "")
|
||||||
answer = val.get('answer', '')
|
answer = val.get("answer", "")
|
||||||
category = val.get('category', -1)
|
category = val.get("category", -1)
|
||||||
evidence = val.get('evidence', [])
|
evidence = val.get("evidence", [])
|
||||||
adversarial_answer = val.get('adversarial_answer', '')
|
adversarial_answer = val.get("adversarial_answer", "")
|
||||||
|
|
||||||
response, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time = self.answer_question(
|
(
|
||||||
speaker_a_user_id,
|
response,
|
||||||
speaker_b_user_id,
|
speaker_1_memories,
|
||||||
question,
|
speaker_2_memories,
|
||||||
answer,
|
speaker_1_memory_time,
|
||||||
category
|
speaker_2_memory_time,
|
||||||
)
|
speaker_1_graph_memories,
|
||||||
|
speaker_2_graph_memories,
|
||||||
|
response_time,
|
||||||
|
) = self.answer_question(speaker_a_user_id, speaker_b_user_id, question, answer, category)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"question": question,
|
"question": question,
|
||||||
@@ -125,67 +153,63 @@ class MemorySearch:
|
|||||||
"adversarial_answer": adversarial_answer,
|
"adversarial_answer": adversarial_answer,
|
||||||
"speaker_1_memories": speaker_1_memories,
|
"speaker_1_memories": speaker_1_memories,
|
||||||
"speaker_2_memories": speaker_2_memories,
|
"speaker_2_memories": speaker_2_memories,
|
||||||
'num_speaker_1_memories': len(speaker_1_memories),
|
"num_speaker_1_memories": len(speaker_1_memories),
|
||||||
'num_speaker_2_memories': len(speaker_2_memories),
|
"num_speaker_2_memories": len(speaker_2_memories),
|
||||||
'speaker_1_memory_time': speaker_1_memory_time,
|
"speaker_1_memory_time": speaker_1_memory_time,
|
||||||
'speaker_2_memory_time': speaker_2_memory_time,
|
"speaker_2_memory_time": speaker_2_memory_time,
|
||||||
"speaker_1_graph_memories": speaker_1_graph_memories,
|
"speaker_1_graph_memories": speaker_1_graph_memories,
|
||||||
"speaker_2_graph_memories": speaker_2_graph_memories,
|
"speaker_2_graph_memories": speaker_2_graph_memories,
|
||||||
"response_time": response_time
|
"response_time": response_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save results after each question is processed
|
# Save results after each question is processed
|
||||||
with open(self.output_path, 'w') as f:
|
with open(self.output_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_data_file(self, file_path):
|
def process_data_file(self, file_path):
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
||||||
qa = item['qa']
|
qa = item["qa"]
|
||||||
conversation = item['conversation']
|
conversation = item["conversation"]
|
||||||
speaker_a = conversation['speaker_a']
|
speaker_a = conversation["speaker_a"]
|
||||||
speaker_b = conversation['speaker_b']
|
speaker_b = conversation["speaker_b"]
|
||||||
|
|
||||||
speaker_a_user_id = f"{speaker_a}_{idx}"
|
speaker_a_user_id = f"{speaker_a}_{idx}"
|
||||||
speaker_b_user_id = f"{speaker_b}_{idx}"
|
speaker_b_user_id = f"{speaker_b}_{idx}"
|
||||||
|
|
||||||
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
|
for question_item in tqdm(
|
||||||
result = self.process_question(
|
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
|
||||||
question_item,
|
):
|
||||||
speaker_a_user_id,
|
result = self.process_question(question_item, speaker_a_user_id, speaker_b_user_id)
|
||||||
speaker_b_user_id
|
|
||||||
)
|
|
||||||
self.results[idx].append(result)
|
self.results[idx].append(result)
|
||||||
|
|
||||||
# Save results after each question is processed
|
# Save results after each question is processed
|
||||||
with open(self.output_path, 'w') as f:
|
with open(self.output_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
# Final save at the end
|
# Final save at the end
|
||||||
with open(self.output_path, 'w') as f:
|
with open(self.output_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1):
|
def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1):
|
||||||
def process_single_question(val):
|
def process_single_question(val):
|
||||||
result = self.process_question(val, speaker_a_user_id, speaker_b_user_id)
|
result = self.process_question(val, speaker_a_user_id, speaker_b_user_id)
|
||||||
# Save results after each question is processed
|
# Save results after each question is processed
|
||||||
with open(self.output_path, 'w') as f:
|
with open(self.output_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
results = list(tqdm(
|
results = list(
|
||||||
executor.map(process_single_question, qa_list),
|
tqdm(executor.map(process_single_question, qa_list), total=len(qa_list), desc="Answering Questions")
|
||||||
total=len(qa_list),
|
)
|
||||||
desc="Answering Questions"
|
|
||||||
))
|
|
||||||
|
|
||||||
# Final save at the end
|
# Final save at the end
|
||||||
with open(self.output_path, 'w') as f:
|
with open(self.output_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -59,23 +59,19 @@ class OpenAIPredict:
|
|||||||
self.results = defaultdict(list)
|
self.results = defaultdict(list)
|
||||||
|
|
||||||
def search_memory(self, idx):
|
def search_memory(self, idx):
|
||||||
|
with open(f"memories/{idx}.txt", "r") as file:
|
||||||
with open(f'memories/{idx}.txt', 'r') as file:
|
|
||||||
memories = file.read()
|
memories = file.read()
|
||||||
|
|
||||||
return memories, 0
|
return memories, 0
|
||||||
|
|
||||||
def process_question(self, val, idx):
|
def process_question(self, val, idx):
|
||||||
question = val.get('question', '')
|
question = val.get("question", "")
|
||||||
answer = val.get('answer', '')
|
answer = val.get("answer", "")
|
||||||
category = val.get('category', -1)
|
category = val.get("category", -1)
|
||||||
evidence = val.get('evidence', [])
|
evidence = val.get("evidence", [])
|
||||||
adversarial_answer = val.get('adversarial_answer', '')
|
adversarial_answer = val.get("adversarial_answer", "")
|
||||||
|
|
||||||
response, search_memory_time, response_time, context = self.answer_question(
|
response, search_memory_time, response_time, context = self.answer_question(idx, question)
|
||||||
idx,
|
|
||||||
question
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"question": question,
|
"question": question,
|
||||||
@@ -86,7 +82,7 @@ class OpenAIPredict:
|
|||||||
"adversarial_answer": adversarial_answer,
|
"adversarial_answer": adversarial_answer,
|
||||||
"search_memory_time": search_memory_time,
|
"search_memory_time": search_memory_time,
|
||||||
"response_time": response_time,
|
"response_time": response_time,
|
||||||
"context": context
|
"context": context,
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -95,43 +91,35 @@ class OpenAIPredict:
|
|||||||
memories, search_memory_time = self.search_memory(idx)
|
memories, search_memory_time = self.search_memory(idx)
|
||||||
|
|
||||||
template = Template(ANSWER_PROMPT)
|
template = Template(ANSWER_PROMPT)
|
||||||
answer_prompt = template.render(
|
answer_prompt = template.render(memories=memories, question=question)
|
||||||
memories=memories,
|
|
||||||
question=question
|
|
||||||
)
|
|
||||||
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
response = self.openai_client.chat.completions.create(
|
response = self.openai_client.chat.completions.create(
|
||||||
model=os.getenv("MODEL"),
|
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": answer_prompt}
|
|
||||||
],
|
|
||||||
temperature=0.0
|
|
||||||
)
|
)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
response_time = t2 - t1
|
response_time = t2 - t1
|
||||||
return response.choices[0].message.content, search_memory_time, response_time, memories
|
return response.choices[0].message.content, search_memory_time, response_time, memories
|
||||||
|
|
||||||
def process_data_file(self, file_path, output_file_path):
|
def process_data_file(self, file_path, output_file_path):
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
||||||
qa = item['qa']
|
qa = item["qa"]
|
||||||
|
|
||||||
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
|
for question_item in tqdm(
|
||||||
result = self.process_question(
|
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
|
||||||
question_item,
|
):
|
||||||
idx
|
result = self.process_question(question_item, idx)
|
||||||
)
|
|
||||||
self.results[idx].append(result)
|
self.results[idx].append(result)
|
||||||
|
|
||||||
# Save results after each question is processed
|
# Save results after each question is processed
|
||||||
with open(output_file_path, 'w') as f:
|
with open(output_file_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
# Final save at the end
|
# Final save at the end
|
||||||
with open(output_file_path, 'w') as f:
|
with open(output_file_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
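These scripts render their answer prompts with Jinja2 before calling the chat model; a self-contained sketch of that step, using a made-up template string rather than the repo's ANSWER_PROMPT, is:

from jinja2 import Template

ANSWER_PROMPT = "Answer using these memories:\n{{ memories }}\n\nQuestion: {{ question }}"
# render the prompt the same way the scripts do: Template(...).render(...)
answer_prompt = Template(ANSWER_PROMPT).render(
    memories="- user enjoys sci-fi movies",
    question="What movie genre does the user prefer?",
)
print(answer_prompt)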
@@ -141,4 +129,3 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
openai_predict = OpenAIPredict()
|
openai_predict = OpenAIPredict()
|
||||||
openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path)
|
openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path)
|
||||||
|
|
||||||
|
|||||||
@@ -33,10 +33,7 @@ class RAGManager:
|
|||||||
|
|
||||||
def generate_response(self, question, context):
|
def generate_response(self, question, context):
|
||||||
template = Template(PROMPT)
|
template = Template(PROMPT)
|
||||||
prompt = template.render(
|
prompt = template.render(CONTEXT=context, QUESTION=question)
|
||||||
CONTEXT=context,
|
|
||||||
QUESTION=question
|
|
||||||
)
|
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
retries = 0
|
retries = 0
|
||||||
@@ -47,19 +44,21 @@ class RAGManager:
|
|||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system",
|
{
|
||||||
"content": "You are a helpful assistant that can answer "
|
"role": "system",
|
||||||
"questions based on the provided context."
|
"content": "You are a helpful assistant that can answer "
|
||||||
"If the question involves timing, use the conversation date for reference."
|
"questions based on the provided context."
|
||||||
"Provide the shortest possible answer."
|
"If the question involves timing, use the conversation date for reference."
|
||||||
"Use words directly from the conversation when possible."
|
"Provide the shortest possible answer."
|
||||||
"Avoid using subjects in your answer."},
|
"Use words directly from the conversation when possible."
|
||||||
{"role": "user", "content": prompt}
|
"Avoid using subjects in your answer.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
temperature=0
|
temperature=0,
|
||||||
)
|
)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
return response.choices[0].message.content.strip(), t2-t1
|
return response.choices[0].message.content.strip(), t2 - t1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
retries += 1
|
retries += 1
|
||||||
if retries > max_retries:
|
if retries > max_retries:
|
||||||
@@ -69,21 +68,16 @@ class RAGManager:
|
|||||||
def clean_chat_history(self, chat_history):
|
def clean_chat_history(self, chat_history):
|
||||||
cleaned_chat_history = ""
|
cleaned_chat_history = ""
|
||||||
for c in chat_history:
|
for c in chat_history:
|
||||||
cleaned_chat_history += (f"{c['timestamp']} | {c['speaker']}: "
|
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
|
||||||
f"{c['text']}\n")
|
|
||||||
|
|
||||||
return cleaned_chat_history
|
return cleaned_chat_history
|
||||||
|
|
||||||
def calculate_embedding(self, document):
|
def calculate_embedding(self, document):
|
||||||
response = self.client.embeddings.create(
|
response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
|
||||||
model=os.getenv("EMBEDDING_MODEL"),
|
|
||||||
input=document
|
|
||||||
)
|
|
||||||
return response.data[0].embedding
|
return response.data[0].embedding
|
||||||
|
|
||||||
def calculate_similarity(self, embedding1, embedding2):
|
def calculate_similarity(self, embedding1, embedding2):
|
||||||
return np.dot(embedding1, embedding2) / (
|
return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
|
||||||
np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
|
|
||||||
|
|
||||||
def search(self, query, chunks, embeddings, k=1):
|
def search(self, query, chunks, embeddings, k=1):
|
||||||
"""
|
"""
|
||||||
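calculate_similarity above is plain cosine similarity; here is a standalone sketch of how search() ranks chunks with it (the embeddings are toy vectors, not output of the OpenAI embedder):

import numpy as np

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

query_embedding = np.array([0.1, 0.9, 0.0])
chunk_embeddings = [np.array([0.2, 0.8, 0.1]), np.array([0.9, 0.1, 0.0])]
# score every chunk against the query and keep the best one (the k=1 case)
scores = [cosine_similarity(query_embedding, e) for e in chunk_embeddings]
best_chunk_index = int(np.argmax(scores))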
@@ -101,10 +95,7 @@ class RAGManager:
|
|||||||
"""
|
"""
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
query_embedding = self.calculate_embedding(query)
|
query_embedding = self.calculate_embedding(query)
|
||||||
similarities = [
|
similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]
|
||||||
self.calculate_similarity(query_embedding, embedding)
|
|
||||||
for embedding in embeddings
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get indices of top-k most similar chunks
|
# Get indices of top-k most similar chunks
|
||||||
if k == 1:
|
if k == 1:
|
||||||
@@ -118,7 +109,7 @@ class RAGManager:
|
|||||||
combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])
|
combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])
|
||||||
|
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
return combined_chunks, t2-t1
|
return combined_chunks, t2 - t1
|
||||||
|
|
||||||
def create_chunks(self, chat_history, chunk_size=500):
|
def create_chunks(self, chat_history, chunk_size=500):
|
||||||
"""
|
"""
|
||||||
@@ -139,7 +130,7 @@ class RAGManager:
|
|||||||
|
|
||||||
# Split into chunks based on token count
|
# Split into chunks based on token count
|
||||||
for i in range(0, len(tokens), chunk_size):
|
for i in range(0, len(tokens), chunk_size):
|
||||||
chunk_tokens = tokens[i:i+chunk_size]
|
chunk_tokens = tokens[i : i + chunk_size]
|
||||||
chunk = encoding.decode(chunk_tokens)
|
chunk = encoding.decode(chunk_tokens)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
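create_chunks splits the chat history on token boundaries using the slice reformatted above; a minimal sketch of that chunking, assuming the cl100k_base encoding purely for illustration, is:

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding, for illustration only
text = "timestamp | speaker: message\n" * 2000
tokens = encoding.encode(text)
chunk_size = 500
# decode each 500-token window back into a text chunk
chunks = [encoding.decode(tokens[i : i + chunk_size]) for i in range(0, len(tokens), chunk_size)]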
@@ -159,13 +150,9 @@ class RAGManager:
|
|||||||
chat_history = value["conversation"]
|
chat_history = value["conversation"]
|
||||||
questions = value["question"]
|
questions = value["question"]
|
||||||
|
|
||||||
chunks, embeddings = self.create_chunks(
|
chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)
|
||||||
chat_history, self.chunk_size
|
|
||||||
)
|
|
||||||
|
|
||||||
for item in tqdm(
|
for item in tqdm(questions, desc="Answering questions", leave=False):
|
||||||
questions, desc="Answering questions", leave=False
|
|
||||||
):
|
|
||||||
question = item["question"]
|
question = item["question"]
|
||||||
answer = item.get("answer", "")
|
answer = item.get("answer", "")
|
||||||
category = item["category"]
|
category = item["category"]
|
||||||
@@ -174,22 +161,20 @@ class RAGManager:
|
|||||||
context = chunks[0]
|
context = chunks[0]
|
||||||
search_time = 0
|
search_time = 0
|
||||||
else:
|
else:
|
||||||
context, search_time = self.search(
|
context, search_time = self.search(question, chunks, embeddings, k=self.k)
|
||||||
question, chunks, embeddings, k=self.k
|
response, response_time = self.generate_response(question, context)
|
||||||
)
|
|
||||||
response, response_time = self.generate_response(
|
|
||||||
question, context
|
|
||||||
)
|
|
||||||
|
|
||||||
FINAL_RESULTS[key].append({
|
FINAL_RESULTS[key].append(
|
||||||
"question": question,
|
{
|
||||||
"answer": answer,
|
"question": question,
|
||||||
"category": category,
|
"answer": answer,
|
||||||
"context": context,
|
"category": category,
|
||||||
"response": response,
|
"context": context,
|
||||||
"search_time": search_time,
|
"response": response,
|
||||||
"response_time": response_time,
|
"search_time": search_time,
|
||||||
})
|
"response_time": response_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
with open(output_file_path, "w+") as f:
|
with open(output_file_path, "w+") as f:
|
||||||
json.dump(FINAL_RESULTS, f, indent=4)
|
json.dump(FINAL_RESULTS, f, indent=4)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,3 @@
|
|||||||
TECHNIQUES = [
|
TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
|
||||||
"mem0",
|
|
||||||
"rag",
|
|
||||||
"langmem",
|
|
||||||
"zep",
|
|
||||||
"openai"
|
|
||||||
]
|
|
||||||
|
|
||||||
METHODS = [
|
METHODS = ["add", "search"]
|
||||||
"add",
|
|
||||||
"search"
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -19,12 +19,12 @@ class ZepAdd:
|
|||||||
self.load_data()
|
self.load_data()
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
with open(self.data_path, 'r') as f:
|
with open(self.data_path, "r") as f:
|
||||||
self.data = json.load(f)
|
self.data = json.load(f)
|
||||||
return self.data
|
return self.data
|
||||||
|
|
||||||
def process_conversation(self, run_id, item, idx):
|
def process_conversation(self, run_id, item, idx):
|
||||||
conversation = item['conversation']
|
conversation = item["conversation"]
|
||||||
|
|
||||||
user_id = f"run_id_{run_id}_experiment_user_{idx}"
|
user_id = f"run_id_{run_id}_experiment_user_{idx}"
|
||||||
session_id = f"run_id_{run_id}_experiment_session_{idx}"
|
session_id = f"run_id_{run_id}_experiment_session_{idx}"
|
||||||
@@ -41,7 +41,7 @@ class ZepAdd:
|
|||||||
|
|
||||||
print("Starting to add memories... for user", user_id)
|
print("Starting to add memories... for user", user_id)
|
||||||
for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
|
for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
|
||||||
if key in ['speaker_a', 'speaker_b'] or "date" in key:
|
if key in ["speaker_a", "speaker_b"] or "date" in key:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
date_time_key = key + "_date_time"
|
date_time_key = key + "_date_time"
|
||||||
@@ -51,11 +51,13 @@ class ZepAdd:
|
|||||||
for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
|
for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
|
||||||
self.zep_client.memory.add(
|
self.zep_client.memory.add(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=[Message(
|
messages=[
|
||||||
role=chat['speaker'],
|
Message(
|
||||||
role_type="user",
|
role=chat["speaker"],
|
||||||
content=f"{timestamp}: {chat['text']}",
|
role_type="user",
|
||||||
)]
|
content=f"{timestamp}: {chat['text']}",
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_all_conversations(self, run_id):
|
def process_all_conversations(self, run_id):
|
||||||
@@ -71,4 +73,4 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--run_id", type=str, required=True)
|
parser.add_argument("--run_id", type=str, required=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
zep_add = ZepAdd(data_path="../../dataset/locomo10.json")
|
zep_add = ZepAdd(data_path="../../dataset/locomo10.json")
|
||||||
zep_add.process_all_conversations(args.run_id)
|
zep_add.process_all_conversations(args.run_id)
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ class ZepSearch:
|
|||||||
return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"
|
return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"
|
||||||
|
|
||||||
def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str:
|
def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str:
|
||||||
facts = [f' - {edge.fact} ({self.format_edge_date_range(edge)})' for edge in edges]
|
facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges]
|
||||||
entities = [f' - {node.name}: {node.summary}' for node in nodes]
|
entities = [f" - {node.name}: {node.summary}" for node in nodes]
|
||||||
return TEMPLATE.format(facts='\n'.join(facts), entities='\n'.join(entities))
|
return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))
|
||||||
|
|
||||||
def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
|
def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
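compose_search_context above simply formats the retrieved facts and entities into the TEMPLATE string; a toy sketch with simplified inputs (the real TEMPLATE and the Zep edge/node objects are defined elsewhere in the repo) is:

TEMPLATE = "FACTS:\n{facts}\n\nENTITIES:\n{entities}"  # simplified stand-in for the real template

facts = [" - Alice moved to Berlin (2023-01-01 - present)"]
entities = [" - Alice: a software engineer who recently relocated"]
context = TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))
print(context)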
@@ -52,8 +52,14 @@ class ZepSearch:
|
|||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
user_id = f"run_id_{run_id}_experiment_user_{idx}"
|
user_id = f"run_id_{run_id}_experiment_user_{idx}"
|
||||||
edges_results = (self.zep_client.graph.search(user_id=user_id, reranker='cross_encoder', query=query, scope='edges', limit=20)).edges
|
edges_results = (
|
||||||
node_results = (self.zep_client.graph.search(user_id=user_id, reranker='rrf', query=query, scope='nodes', limit=20)).nodes
|
self.zep_client.graph.search(
|
||||||
|
user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
|
||||||
|
)
|
||||||
|
).edges
|
||||||
|
node_results = (
|
||||||
|
self.zep_client.graph.search(user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20)
|
||||||
|
).nodes
|
||||||
context = self.compose_search_context(edges_results, node_results)
|
context = self.compose_search_context(edges_results, node_results)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -68,17 +74,13 @@ class ZepSearch:
|
|||||||
return context, end_time - start_time
|
return context, end_time - start_time
|
||||||
|
|
||||||
def process_question(self, run_id, val, idx):
|
def process_question(self, run_id, val, idx):
|
||||||
question = val.get('question', '')
|
question = val.get("question", "")
|
||||||
answer = val.get('answer', '')
|
answer = val.get("answer", "")
|
||||||
category = val.get('category', -1)
|
category = val.get("category", -1)
|
||||||
evidence = val.get('evidence', [])
|
evidence = val.get("evidence", [])
|
||||||
adversarial_answer = val.get('adversarial_answer', '')
|
adversarial_answer = val.get("adversarial_answer", "")
|
||||||
|
|
||||||
response, search_memory_time, response_time, context = self.answer_question(
|
response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)
|
||||||
run_id,
|
|
||||||
idx,
|
|
||||||
question
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"question": question,
|
"question": question,
|
||||||
@@ -89,7 +91,7 @@ class ZepSearch:
|
|||||||
"adversarial_answer": adversarial_answer,
|
"adversarial_answer": adversarial_answer,
|
||||||
"search_memory_time": search_memory_time,
|
"search_memory_time": search_memory_time,
|
||||||
"response_time": response_time,
|
"response_time": response_time,
|
||||||
"context": context
|
"context": context,
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -98,44 +100,35 @@ class ZepSearch:
|
|||||||
context, search_memory_time = self.search_memory(run_id, idx, question)
|
context, search_memory_time = self.search_memory(run_id, idx, question)
|
||||||
|
|
||||||
template = Template(ANSWER_PROMPT_ZEP)
|
template = Template(ANSWER_PROMPT_ZEP)
|
||||||
answer_prompt = template.render(
|
answer_prompt = template.render(memories=context, question=question)
|
||||||
memories=context,
|
|
||||||
question=question
|
|
||||||
)
|
|
||||||
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
response = self.openai_client.chat.completions.create(
|
response = self.openai_client.chat.completions.create(
|
||||||
model=os.getenv("MODEL"),
|
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": answer_prompt}
|
|
||||||
],
|
|
||||||
temperature=0.0
|
|
||||||
)
|
)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
response_time = t2 - t1
|
response_time = t2 - t1
|
||||||
return response.choices[0].message.content, search_memory_time, response_time, context
|
return response.choices[0].message.content, search_memory_time, response_time, context
|
||||||
|
|
||||||
def process_data_file(self, file_path, run_id, output_file_path):
|
def process_data_file(self, file_path, run_id, output_file_path):
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
||||||
qa = item['qa']
|
qa = item["qa"]
|
||||||
|
|
||||||
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
|
for question_item in tqdm(
|
||||||
result = self.process_question(
|
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
|
||||||
run_id,
|
):
|
||||||
question_item,
|
result = self.process_question(run_id, question_item, idx)
|
||||||
idx
|
|
||||||
)
|
|
||||||
self.results[idx].append(result)
|
self.results[idx].append(result)
|
||||||
|
|
||||||
# Save results after each question is processed
|
# Save results after each question is processed
|
||||||
with open(output_file_path, 'w') as f:
|
with open(output_file_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
# Final save at the end
|
# Final save at the end
|
||||||
with open(output_file_path, 'w') as f:
|
with open(output_file_path, "w") as f:
|
||||||
json.dump(self.results, f, indent=4)
|
json.dump(self.results, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -56,9 +56,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"OPENAI_API_KEY\"] = (\n",
|
"os.environ[\"OPENAI_API_KEY\"] = \"\""
|
||||||
" \"\"\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -149,7 +147,7 @@
|
|||||||
" \"role\": \"assistant\",\n",
|
" \"role\": \"assistant\",\n",
|
||||||
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
|
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
"]\n"
|
"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -166,9 +164,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Store inferred memories (default behavior)\n",
|
"# Store inferred memories (default behavior)\n",
|
||||||
"result = m.add(\n",
|
"result = m.add(messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"})"
|
||||||
" messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"}\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -20,19 +20,19 @@ agent = Agent(
|
|||||||
name="Fitness Agent",
|
name="Fitness Agent",
|
||||||
model=OpenAIChat(id="gpt-4o"),
|
model=OpenAIChat(id="gpt-4o"),
|
||||||
description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.",
|
description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.",
|
||||||
markdown=True
|
markdown=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Store user preferences as memory
|
# Store user preferences as memory
|
||||||
def store_user_preferences(conversation: list, user_id: str = USER_ID):
|
def store_user_preferences(conversation: list, user_id: str = USER_ID):
|
||||||
"""Store user preferences from conversation history"""
|
"""Store user preferences from conversation history"""
|
||||||
memory_client.add(conversation, user_id=user_id, output_format='v1.1')
|
memory_client.add(conversation, user_id=user_id, output_format="v1.1")
|
||||||
|
|
||||||
|
|
||||||
# Memory-aware assistant function
|
# Memory-aware assistant function
|
||||||
def fitness_coach(user_input: str, user_id: str = USER_ID):
|
def fitness_coach(user_input: str, user_id: str = USER_ID):
|
||||||
memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories based on user query
|
memories = memory_client.search(user_input, user_id=user_id)  # Search relevant memories based on user query
|
||||||
memory_context = "\n".join(f"- {m['memory']}" for m in memories)
|
memory_context = "\n".join(f"- {m['memory']}" for m in memories)
|
||||||
|
|
||||||
prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations.
|
prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations.
|
||||||
@@ -48,113 +48,66 @@ User query:
|
|||||||
memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id)
|
memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id)
|
||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------
|
# --------------------------------------------------
|
||||||
# Store user preferences and memories
|
# Store user preferences and memories
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "Hi, I’m Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle."
|
"content": "Hi, I’m Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago."
|
"content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday."
|
"content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays."
|
"content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays.",
|
||||||
},
|
},
|
||||||
|
{"role": "user", "content": "After push days, I usually eat high-protein and moderate-carb meals to recover."},
|
||||||
|
{"role": "assistant", "content": "Noted — high-protein, moderate-carb meals after push workouts."},
|
||||||
|
{"role": "user", "content": "For pull days, I take whey protein and eat a banana after training."},
|
||||||
|
{"role": "assistant", "content": "Logged — whey protein and banana post pull workouts."},
|
||||||
|
{"role": "user", "content": "On leg days, I make sure to have complex carbs like rice or oats."},
|
||||||
|
{"role": "assistant", "content": "Noted — complex carbs like rice and oats are part of your leg day meals."},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "After push days, I usually eat high-protein and moderate-carb meals to recover."
|
"content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery.",
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Noted — high-protein, moderate-carb meals after push workouts."
|
|
||||||
},
|
},
|
||||||
|
{"role": "assistant", "content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "For pull days, I take whey protein and eat a banana after training."
|
"content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "Logged — whey protein and banana post pull workouts."
|
"content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward.",
|
||||||
},
|
},
|
||||||
|
{"role": "user", "content": "I prefer light dinners post-workout like tofu, soup, and vegetables."},
|
||||||
|
{"role": "assistant", "content": "Got it — light dinners post-workout: tofu, soup, and veggies."},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "On leg days, I make sure to have complex carbs like rice or oats."
|
"content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey.",
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Noted — complex carbs like rice and oats are part of your leg day meals."
|
|
||||||
},
|
},
|
||||||
|
{"role": "assistant", "content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery."
|
"content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."
|
"content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges.",
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I prefer light dinners post-workout like tofu, soup, and vegetables."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Got it — light dinners post-workout: tofu, soup, and veggies."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I track sleep and notice poor performance when I sleep less than 6 hours."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Logged — performance drops when you get under 6 hours of sleep."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I take magnesium supplements to help with muscle recovery and sleep quality."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Remembered — magnesium helps you with recovery and sleep."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "I avoid caffeine after 4 PM because it affects my sleep."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Got it — you avoid caffeine post-4 PM to protect your sleep."
|
|
||||||
},
|
},
|
||||||
|
{"role": "user", "content": "I track sleep and notice poor performance when I sleep less than 6 hours."},
|
||||||
|
{"role": "assistant", "content": "Logged — performance drops when you get under 6 hours of sleep."},
|
||||||
|
{"role": "user", "content": "I take magnesium supplements to help with muscle recovery and sleep quality."},
|
||||||
|
{"role": "assistant", "content": "Remembered — magnesium helps you with recovery and sleep."},
|
||||||
|
{"role": "user", "content": "I avoid caffeine after 4 PM because it affects my sleep."},
|
||||||
|
{"role": "assistant", "content": "Got it — you avoid caffeine post-4 PM to protect your sleep."},
|
||||||
]
|
]
|
||||||
store_user_preferences(messages)
|
store_user_preferences(messages)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from google.adk.agents import Agent
|
from google.adk.agents import Agent
|
||||||
from google.adk.sessions import InMemorySessionService
|
|
||||||
from google.adk.runners import Runner
|
from google.adk.runners import Runner
|
||||||
|
from google.adk.sessions import InMemorySessionService
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from mem0 import MemoryClient
|
from mem0 import MemoryClient
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
@@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict:
|
|||||||
print(f"Storing patient information: {information[:30]}...")
|
print(f"Storing patient information: {information[:30]}...")
|
||||||
|
|
||||||
# Get user_id from session state or use default
|
# Get user_id from session state or use default
|
||||||
user_id = getattr(save_patient_info, 'user_id', 'default_user')
|
user_id = getattr(save_patient_info, "user_id", "default_user")
|
||||||
|
|
||||||
# Store in Mem0
|
# Store in Mem0
|
||||||
response = mem0_client.add(
|
mem0_client.add(
|
||||||
[{"role": "user", "content": information}],
|
[{"role": "user", "content": information}],
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
run_id="healthcare_session",
|
run_id="healthcare_session",
|
||||||
metadata={"type": "patient_information"}
|
metadata={"type": "patient_information"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"status": "success", "message": "Information saved"}
|
return {"status": "success", "message": "Information saved"}
|
||||||
@@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str:
|
|||||||
print(f"Searching for patient information: {query}")
|
print(f"Searching for patient information: {query}")
|
||||||
|
|
||||||
# Get user_id from session state or use default
|
# Get user_id from session state or use default
|
||||||
user_id = getattr(retrieve_patient_info, 'user_id', 'default_user')
|
user_id = getattr(retrieve_patient_info, "user_id", "default_user")
|
||||||
|
|
||||||
# Search Mem0
|
# Search Mem0
|
||||||
results = mem0_client.search(
|
results = mem0_client.search(
|
||||||
@@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
run_id="healthcare_session",
|
run_id="healthcare_session",
|
||||||
limit=5,
|
limit=5,
|
||||||
threshold=0.7 # Higher threshold for more relevant results
|
threshold=0.7, # Higher threshold for more relevant results
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict:
|
|||||||
"status": "success",
|
"status": "success",
|
||||||
"appointment_id": appointment_id,
|
"appointment_id": appointment_id,
|
||||||
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
|
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
|
||||||
"message": "Please arrive 15 minutes early to complete paperwork."
|
"message": "Please arrive 15 minutes early to complete paperwork.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ IMPORTANT GUIDELINES:
|
|||||||
- For serious symptoms, always recommend consulting a healthcare professional.
|
- For serious symptoms, always recommend consulting a healthcare professional.
|
||||||
- Keep all patient information confidential.
|
- Keep all patient information confidential.
|
||||||
""",
|
""",
|
||||||
tools=[save_patient_info, retrieve_patient_info, schedule_appointment]
|
tools=[save_patient_info, retrieve_patient_info, schedule_appointment],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set Up Session and Runner
|
# Set Up Session and Runner
|
||||||
@@ -101,18 +103,10 @@ USER_ID = "Alex"
|
|||||||
SESSION_ID = "session_001"
|
SESSION_ID = "session_001"
|
||||||
|
|
||||||
# Create a session
|
# Create a session
|
||||||
session = session_service.create_session(
|
session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID)
|
||||||
app_name=APP_NAME,
|
|
||||||
user_id=USER_ID,
|
|
||||||
session_id=SESSION_ID
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the runner
|
# Create the runner
|
||||||
runner = Runner(
|
runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service)
|
||||||
agent=healthcare_agent,
|
|
||||||
app_name=APP_NAME,
|
|
||||||
session_service=session_service
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Interact with the Healthcare Assistant
|
# Interact with the Healthcare Assistant
|
||||||
@@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id):
|
|||||||
print(f"\n>>> Patient: {query}")
|
print(f"\n>>> Patient: {query}")
|
||||||
|
|
||||||
# Format the user's message
|
# Format the user's message
|
||||||
content = types.Content(
|
content = types.Content(role="user", parts=[types.Part(text=query)])
|
||||||
role='user',
|
|
||||||
parts=[types.Part(text=query)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set user_id for tools to access
|
# Set user_id for tools to access
|
||||||
save_patient_info.user_id = user_id
|
save_patient_info.user_id = user_id
|
||||||
retrieve_patient_info.user_id = user_id
|
retrieve_patient_info.user_id = user_id
|
||||||
|
|
||||||
# Run the agent
|
# Run the agent
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content):
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
new_message=content
|
|
||||||
):
|
|
||||||
if event.is_final_response():
|
if event.is_final_response():
|
||||||
if event.content and event.content.parts:
|
if event.content and event.content.parts:
|
||||||
response = event.content.parts[0].text
|
response = event.content.parts[0].text
|
||||||
@@ -152,7 +139,7 @@ async def run_conversation():
|
|||||||
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
|
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
|
||||||
runner=runner,
|
runner=runner,
|
||||||
user_id=USER_ID,
|
user_id=USER_ID,
|
||||||
session_id=SESSION_ID
|
session_id=SESSION_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request for health information
|
# Request for health information
|
||||||
@@ -160,7 +147,7 @@ async def run_conversation():
|
|||||||
"Can you tell me more about what might be causing my headaches?",
|
"Can you tell me more about what might be causing my headaches?",
|
||||||
runner=runner,
|
runner=runner,
|
||||||
user_id=USER_ID,
|
user_id=USER_ID,
|
||||||
session_id=SESSION_ID
|
session_id=SESSION_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schedule an appointment
|
# Schedule an appointment
|
||||||
@@ -168,15 +155,12 @@ async def run_conversation():
|
|||||||
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
|
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
|
||||||
runner=runner,
|
runner=runner,
|
||||||
user_id=USER_ID,
|
user_id=USER_ID,
|
||||||
session_id=SESSION_ID
|
session_id=SESSION_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test memory - should remember patient name, symptoms, and allergy
|
# Test memory - should remember patient name, symptoms, and allergy
|
||||||
await call_agent_async(
|
await call_agent_async(
|
||||||
"What medications should I avoid for my headaches?",
|
"What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID
|
||||||
runner=runner,
|
|
||||||
user_id=USER_ID,
|
|
||||||
session_id=SESSION_ID
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -191,37 +175,28 @@ async def interactive_mode():
|
|||||||
session_id = f"session_{hash(patient_id) % 1000:03d}"
|
session_id = f"session_{hash(patient_id) % 1000:03d}"
|
||||||
|
|
||||||
# Create session for this user
|
# Create session for this user
|
||||||
session = session_service.create_session(
|
session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id)
|
||||||
app_name=APP_NAME,
|
|
||||||
user_id=patient_id,
|
|
||||||
session_id=session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\nStarting conversation with patient ID: {patient_id}")
|
print(f"\nStarting conversation with patient ID: {patient_id}")
|
||||||
print("Type your message and press Enter.")
|
print("Type your message and press Enter.")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
user_input = input("\n>>> Patient: ").strip()
|
user_input = input("\n>>> Patient: ").strip()
|
||||||
if user_input.lower() in ['exit', 'quit', 'bye']:
|
if user_input.lower() in ["exit", "quit", "bye"]:
|
||||||
print("Ending conversation. Thank you!")
|
print("Ending conversation. Thank you!")
|
||||||
break
|
break
|
||||||
|
|
||||||
await call_agent_async(
|
await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id)
|
||||||
user_input,
|
|
||||||
runner=runner,
|
|
||||||
user_id=patient_id,
|
|
||||||
session_id=session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Main execution
|
# Main execution
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory')
|
parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory")
|
||||||
parser.add_argument('--demo', action='store_true', help='Run the demo conversation')
|
parser.add_argument("--demo", action="store_true", help="Run the demo conversation")
|
||||||
parser.add_argument('--interactive', action='store_true', help='Run in interactive mode')
|
parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
|
||||||
parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation')
|
parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.demo:
|
if args.demo:
|
||||||
@@ -231,5 +206,3 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
# Default to demo mode if no arguments provided
|
# Default to demo mode if no arguments provided
|
||||||
asyncio.run(run_conversation())
|
asyncio.run(run_conversation())
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,26 +16,21 @@ from mem0 import Memory

 # Configure Mem0 with Grok 3 and Qdrant
 config = {
-    "vector_store": {
-        "provider": "qdrant",
-        "config": {
-            "embedding_model_dims": 384
-        }
-    },
+    "vector_store": {"provider": "qdrant", "config": {"embedding_model_dims": 384}},
     "llm": {
         "provider": "xai",
         "config": {
             "model": "grok-3-beta",
             "temperature": 0.1,
             "max_tokens": 2000,
-        }
+        },
     },
     "embedder": {
         "provider": "huggingface",
         "config": {
-            "model": "all-MiniLM-L6-v2" # open embedding model
-        }
-    }
+            "model": "all-MiniLM-L6-v2"  # open embedding model
+        },
+    },
 }

 # Instantiate memory layer
@@ -57,20 +52,14 @@ def recommend_movie_with_memory(user_id: str, user_query: str):
     prompt += f"\nPreviously, the user mentioned: {past_memories}"

     # Generate movie recommendation using Grok 3
-    response = grok_client.chat.completions.create(
-        model="grok-3-beta",
-        messages=[
-            {"role": "user", "content": prompt}
-        ]
-    )
+    response = grok_client.chat.completions.create(model="grok-3-beta", messages=[{"role": "user", "content": prompt}])
     recommendation = response.choices[0].message.content

     # Store conversation in memory
     memory.add(
-        [{"role": "user", "content": user_query},
-         {"role": "assistant", "content": recommendation}],
+        [{"role": "user", "content": user_query}, {"role": "assistant", "content": recommendation}],
         user_id=user_id,
-        metadata={"category": "movie"}
+        metadata={"category": "movie"},
     )

     return recommendation

@@ -81,10 +70,11 @@ if __name__ == "__main__":
     user_id = "arshi"
     recommend_movie_with_memory(user_id, "I'm looking for a movie to watch tonight. Any suggestions?")
     # OUTPUT: You have watched Interstellar last weekend and you don't like horror movies, maybe you can watch "Purple Hearts" today.
-    recommend_movie_with_memory(user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians.")
+    recommend_movie_with_memory(
+        user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians."
+    )
     # OUTPUT: Got it — no sad endings! You might enjoy "The Proposal" or "Love, Rosie". They’re both light-hearted romcoms with happy vibes.
     recommend_movie_with_memory(user_id, "Any light-hearted movie I can watch after work today?")
     # OUTPUT: Since you liked Crazy Rich Asians and The Proposal, how about "The Intern" or "Isn’t It Romantic"? Both are upbeat, funny, and perfect for relaxing.
     recommend_movie_with_memory(user_id, "I’ve already watched The Intern. Something new maybe?")
     # OUTPUT: No problem! Try "Your Place or Mine" - romcoms that match your taste and are tear-free!

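Read as a whole, the reformatted example above boils down to the following minimal sketch. It assumes the standard OSS `Memory.from_config` entry point, an `XAI_API_KEY` in the environment, and a reachable local Qdrant; everything other than the config keys shown in the diff is illustrative.

from mem0 import Memory

# Assumed prerequisites: XAI_API_KEY exported for the xAI provider, Qdrant running locally.
config = {
    "vector_store": {"provider": "qdrant", "config": {"embedding_model_dims": 384}},
    "llm": {"provider": "xai", "config": {"model": "grok-3-beta", "temperature": 0.1, "max_tokens": 2000}},
    "embedder": {"provider": "huggingface", "config": {"model": "all-MiniLM-L6-v2"}},
}
memory = Memory.from_config(config)

# Store one exchange and query it back, scoped to a single user.
memory.add(
    [{"role": "user", "content": "I loved Notting Hill"}, {"role": "assistant", "content": "Noted!"}],
    user_id="arshi",
    metadata={"category": "movie"},
)
hits = memory.search("what movies does the user like?", user_id="arshi")
for hit in hits.get("results", []):  # v1.1-style output; older API versions may return a bare list
    print(hit["memory"])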
@@ -23,8 +23,8 @@ agent = Agent(
     name="Personal Agent",
     model=OpenAIChat(id="gpt-4o"),
     description="You are a helpful personal agent that helps me with day to day activities."
     "You can process both text and images.",
-    markdown=True
+    markdown=True,
 )


@@ -35,24 +35,16 @@ def chat_user(user_input: str = None, user_id: str = "user_123", image_path: str
         base64_image = base64.b64encode(image_file.read()).decode("utf-8")

         # First: the text message
-        text_msg = {
-            "role": "user",
-            "content": user_input
-        }
+        text_msg = {"role": "user", "content": user_input}

         # Second: the image message
         image_msg = {
             "role": "user",
-            "content": {
-                "type": "image_url",
-                "image_url": {
-                    "url": f"data:image/jpeg;base64,{base64_image}"
-                }
-            }
+            "content": {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
         }

         # Send both as separate message objects
-        client.add([text_msg, image_msg], user_id=user_id, output_format='v1.1')
+        client.add([text_msg, image_msg], user_id=user_id, output_format="v1.1")
         print("✅ Image uploaded and stored in memory.")

     if user_input:
@@ -92,10 +84,13 @@ print(chat_user("When is my test?", user_id=user_id))
 # OUTPUT: Your pilot's test is on your birthday, which is in five days. You're turning 25!
 # Good luck with your preparations, and remember to take some time to relax amidst the studying.

-print(chat_user("This is the picture of what I brought with me in the trip to Bahamas",
-    image_path="travel_items.jpeg", # this will be added to Mem0 memory
-    user_id=user_id))
-print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find",
-    user_id=user_id))
+print(
+    chat_user(
+        "This is the picture of what I brought with me in the trip to Bahamas",
+        image_path="travel_items.jpeg",  # this will be added to Mem0 memory
+        user_id=user_id,
+    )
+)
+print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", user_id=user_id))
 # OUTPUT: Yes, you did bring your sunglasses on your trip to the Bahamas along with your laptop, face masks and other items..
 # Since you can't find them now, perhaps check the pockets of jackets you wore or in your luggage compartments.

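The same text-plus-image pattern, reduced to its essentials outside the Agno agent. A hosted MemoryClient with MEM0_API_KEY set is assumed; the helper name and sample file path are illustrative.

import base64

from mem0 import MemoryClient

client = MemoryClient()  # reads MEM0_API_KEY from the environment


def remember_image(path: str, caption: str, user_id: str) -> None:
    """Store a caption plus the image itself as two user messages."""
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    text_msg = {"role": "user", "content": caption}
    image_msg = {
        "role": "user",
        "content": {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}},
    }
    client.add([text_msg, image_msg], user_id=user_id, output_format="v1.1")


remember_image("travel_items.jpeg", "Packing list for the Bahamas trip", user_id="user_123")
print(client.search("did I pack sunglasses?", user_id="user_123"))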
@@ -7,6 +7,7 @@ In order to run this file, you need to set up your Mem0 API at Mem0 platform and
 export OPENAI_API_KEY="your_openai_api_key"
 export MEM0_API_KEY="your_mem0_api_key"
 """

 import asyncio

 from agents import Agent, Runner
@@ -23,25 +24,19 @@ study_agent = Agent(
     - Identify topics the user has struggled with (e.g., "I'm confused", "this is hard")
     - Help with spaced repetition by suggesting topics to revisit based on last review time
     - Personalize answers using stored memories
-    - Summarize PDFs or notes the user uploads""")
+    - Summarize PDFs or notes the user uploads""",
+)


 # Upload and store PDF to Mem0
 def upload_pdf(pdf_url: str, user_id: str):
-    pdf_message = {
-        "role": "user",
-        "content": {
-            "type": "pdf_url",
-            "pdf_url": {"url": pdf_url}
-        }
-    }
+    pdf_message = {"role": "user", "content": {"type": "pdf_url", "pdf_url": {"url": pdf_url}}}
     client.add([pdf_message], user_id=user_id)
     print("✅ PDF uploaded and processed into memory.")


 # Main interaction loop with your personal study buddy
 async def study_buddy(user_id: str, topic: str, user_input: str):

     memories = client.search(f"{topic}", user_id=user_id)
     memory_context = "\n".join(f"- {m['memory']}" for m in memories)

@@ -56,9 +51,11 @@ Now respond to the user's new question or comment:
     result = await Runner.run(study_agent, prompt)
     response = result.final_output

-    client.add([
-        {"role": "user", "content": f'''Topic: {topic}\nUser: {user_input}\n\nStudy Assistant: {response}'''}
-    ], user_id=user_id, metadata={"topic": topic})
+    client.add(
+        [{"role": "user", "content": f"""Topic: {topic}\nUser: {user_input}\n\nStudy Assistant: {response}"""}],
+        user_id=user_id,
+        metadata={"topic": topic},
+    )

     return response

@@ -78,7 +75,12 @@ async def main():

     # Demonstrate spaced repetition prompting
     topic = "Momentum Conservation"
-    print(await study_buddy(user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"))
+    print(
+        await study_buddy(
+            user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"
+        )
+    )


 if __name__ == "__main__":
     asyncio.run(main())

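For quick reference, the PDF-ingestion call above in isolation. The hosted client is assumed, and the URL and user id are placeholders.

from mem0 import MemoryClient

client = MemoryClient()  # expects MEM0_API_KEY in the environment
pdf_message = {"role": "user", "content": {"type": "pdf_url", "pdf_url": {"url": "https://example.com/notes.pdf"}}}
client.add([pdf_message], user_id="student_42")

# Later, pull back whatever was extracted from the document for a given topic.
for m in client.search("Momentum Conservation", user_id="student_42"):
    print("-", m["memory"])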
@@ -57,7 +57,7 @@ def initialize_memory():
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know."
|
"content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -65,7 +65,7 @@ def initialize_memory():
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive."
|
"content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -73,7 +73,7 @@ def initialize_memory():
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time."
|
"content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -81,7 +81,7 @@ def initialize_memory():
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants."
|
"content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -89,7 +89,7 @@ def initialize_memory():
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months."
|
"content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -135,11 +135,11 @@ def record_audio(filename="input.wav", record_seconds=5):
     stream.close()
     p.terminate()

-    with wave.open(filename, 'wb') as wf:
+    with wave.open(filename, "wb") as wf:
         wf.setnchannels(channels)
         wf.setsampwidth(p.get_sample_size(fmt))
         wf.setframerate(rate)
-        wf.writeframes(b''.join(frames))
+        wf.writeframes(b"".join(frames))


 # ------------------ STT USING WHISPER ------------------
@@ -147,10 +147,7 @@ def transcribe_whisper(audio_path):
     print("🔎 Transcribing with Whisper...")
     try:
         with open(audio_path, "rb") as audio_file:
-            transcript = openai_client.audio.transcriptions.create(
-                model="whisper-1",
-                file=audio_file
-            )
+            transcript = openai_client.audio.transcriptions.create(model="whisper-1", file=audio_file)
         print(f"🗣️ You said: {transcript.text}")
         return transcript.text
     except Exception as e:
@@ -165,9 +162,7 @@ def get_agent_response(user_input):

     try:
         task = Task(
-            description=f"Respond to: {user_input}",
-            expected_output="A short and relevant reply.",
-            agent=voice_agent
+            description=f"Respond to: {user_input}", expected_output="A short and relevant reply.", agent=voice_agent
         )
         crew = Crew(
             agents=[voice_agent],
@@ -175,22 +170,19 @@ def get_agent_response(user_input):
             process=Process.sequential,
             verbose=True,
             memory=True,
-            memory_config={
-                "provider": "mem0",
-                "config": {"user_id": USER_ID}
-            }
+            memory_config={"provider": "mem0", "config": {"user_id": USER_ID}},
         )
         result = crew.kickoff()

         # Extract the text response from the complex result object
-        if hasattr(result, 'raw'):
+        if hasattr(result, "raw"):
             return result.raw
-        elif isinstance(result, dict) and 'raw' in result:
-            return result['raw']
-        elif isinstance(result, dict) and 'tasks_output' in result:
-            outputs = result['tasks_output']
+        elif isinstance(result, dict) and "raw" in result:
+            return result["raw"]
+        elif isinstance(result, dict) and "tasks_output" in result:
+            outputs = result["tasks_output"]
             if outputs and isinstance(outputs, list) and len(outputs) > 0:
-                return outputs[0].get('raw', str(result))
+                return outputs[0].get("raw", str(result))

         # Fallback to string representation if we can't extract the raw response
         return str(result)
@@ -204,10 +196,7 @@ def get_agent_response(user_input):
 def speak_response(text):
     print(f"🤖 Agent: {text}")
     audio = tts_client.text_to_speech.convert(
-        text=text,
-        voice_id="JBFqnCBsd6RMkjVDRZzb",
-        model_id="eleven_multilingual_v2",
-        output_format="mp3_44100_128"
+        text=text, voice_id="JBFqnCBsd6RMkjVDRZzb", model_id="eleven_multilingual_v2", output_format="mp3_44100_128"
     )
     play(audio)

@@ -220,7 +209,7 @@ def run_voice_agent():
         record_audio(tmp_audio.name)
         try:
             user_text = transcribe_whisper(tmp_audio.name)
-            if user_text.lower() in ['exit', 'quit', 'stop']:
+            if user_text.lower() in ["exit", "quit", "stop"]:
                 print("👋 Exiting.")
                 break
             response = get_agent_response(user_text)

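Stripped of the audio plumbing, the Mem0-backed CrewAI memory wiring above looks roughly like the sketch below. CrewAI with its Mem0 memory provider and a MEM0_API_KEY are assumed; the agent and task definitions are illustrative rather than taken from this file.

from crewai import Agent, Crew, Process, Task

voice_agent = Agent(role="Voice Assistant", goal="Answer briefly.", backstory="A concise, helpful assistant.")
task = Task(
    description="Respond to: what did I plan for today?",
    expected_output="A short and relevant reply.",
    agent=voice_agent,
)
crew = Crew(
    agents=[voice_agent],
    tasks=[task],
    process=Process.sequential,
    memory=True,
    memory_config={"provider": "mem0", "config": {"user_id": "user_123"}},  # per-user memory backed by Mem0
)
print(crew.kickoff())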
@@ -95,10 +95,7 @@ class MemoryClient:
             self.client = client
             # Ensure the client has the correct base_url and headers
             self.client.base_url = httpx.URL(self.host)
-            self.client.headers.update({
-                "Authorization": f"Token {self.api_key}",
-                "Mem0-User-ID": self.user_id
-            })
+            self.client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
         else:
             self.client = httpx.Client(
                 base_url=self.host,
@@ -237,7 +234,9 @@ class MemoryClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if "metadata" in kwargs:
|
if "metadata" in kwargs:
|
||||||
del kwargs["metadata"]
|
del kwargs["metadata"]
|
||||||
capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"})
|
capture_client_event(
|
||||||
|
"client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}
|
||||||
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -357,10 +356,7 @@ class MemoryClient:
|
|||||||
else:
|
else:
|
||||||
entities = self.users()
|
entities = self.users()
|
||||||
# Filter entities based on provided IDs using list comprehension
|
# Filter entities based on provided IDs using list comprehension
|
||||||
to_delete = [
|
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||||
{"type": entity["type"], "name": entity["name"]}
|
|
||||||
for entity in entities["results"]
|
|
||||||
]
|
|
||||||
|
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
|
|
||||||
@@ -373,7 +369,9 @@ class MemoryClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event(
|
capture_client_event(
|
||||||
"client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"}
|
"client.delete_users",
|
||||||
|
self,
|
||||||
|
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"message": "Entity deleted successfully."
|
"message": "Entity deleted successfully."
|
||||||
@@ -454,7 +452,9 @@ class MemoryClient:
|
|||||||
"""
|
"""
|
||||||
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
|
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"})
|
capture_client_event(
|
||||||
|
"client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}
|
||||||
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
@@ -527,7 +527,11 @@ class MemoryClient:
         )

         payload = self._prepare_params(
-            {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
+            {
+                "custom_instructions": custom_instructions,
+                "custom_categories": custom_categories,
+                "retrieval_criteria": retrieval_criteria,
+            }
         )
         response = self.client.patch(
             f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -537,7 +541,12 @@ class MemoryClient:
         capture_client_event(
             "client.update_project",
             self,
-            {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "sync"},
+            {
+                "custom_instructions": custom_instructions,
+                "custom_categories": custom_categories,
+                "retrieval_criteria": retrieval_criteria,
+                "sync_type": "sync",
+            },
         )
         return response.json()

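From the caller's side, the update_project path being reflowed here is used roughly as in the sketch below. Both org and project ids are required, as the surrounding validation enforces; the concrete values and categories are placeholders.

from mem0 import MemoryClient

client = MemoryClient(org_id="org_123", project_id="proj_456")
client.update_project(
    custom_instructions="Prefer short, factual memories.",
    custom_categories=["preferences", "schedule", "health"],
)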
@@ -750,10 +759,7 @@ class AsyncMemoryClient:
|
|||||||
self.async_client = client
|
self.async_client = client
|
||||||
# Ensure the client has the correct base_url and headers
|
# Ensure the client has the correct base_url and headers
|
||||||
self.async_client.base_url = httpx.URL(self.host)
|
self.async_client.base_url = httpx.URL(self.host)
|
||||||
self.async_client.headers.update({
|
self.async_client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
|
||||||
"Authorization": f"Token {self.api_key}",
|
|
||||||
"Mem0-User-ID": self.user_id
|
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
self.async_client = httpx.AsyncClient(
|
self.async_client = httpx.AsyncClient(
|
||||||
base_url=self.host,
|
base_url=self.host,
|
||||||
@@ -768,7 +774,11 @@ class AsyncMemoryClient:
|
|||||||
"""Validate the API key by making a test request."""
|
"""Validate the API key by making a test request."""
|
||||||
try:
|
try:
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params)
|
response = requests.get(
|
||||||
|
f"{self.host}/v1/ping/",
|
||||||
|
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -973,10 +983,7 @@ class AsyncMemoryClient:
|
|||||||
else:
|
else:
|
||||||
entities = await self.users()
|
entities = await self.users()
|
||||||
# Filter entities based on provided IDs using list comprehension
|
# Filter entities based on provided IDs using list comprehension
|
||||||
to_delete = [
|
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
|
||||||
{"type": entity["type"], "name": entity["name"]}
|
|
||||||
for entity in entities["results"]
|
|
||||||
]
|
|
||||||
|
|
||||||
params = self._prepare_params()
|
params = self._prepare_params()
|
||||||
|
|
||||||
@@ -988,7 +995,11 @@ class AsyncMemoryClient:
|
|||||||
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"})
|
capture_client_event(
|
||||||
|
"client.delete_users",
|
||||||
|
self,
|
||||||
|
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"},
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"message": "Entity deleted successfully."
|
"message": "Entity deleted successfully."
|
||||||
if (user_id or agent_id or app_id or run_id)
|
if (user_id or agent_id or app_id or run_id)
|
||||||
@@ -1091,8 +1102,10 @@ class AsyncMemoryClient:
|
|||||||
|
|
||||||
@api_error_handler
|
@api_error_handler
|
||||||
async def update_project(
|
async def update_project(
|
||||||
self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None,
|
self,
|
||||||
retrieval_criteria: Optional[List[Dict[str, Any]]] = None
|
custom_instructions: Optional[str] = None,
|
||||||
|
custom_categories: Optional[List[str]] = None,
|
||||||
|
retrieval_criteria: Optional[List[Dict[str, Any]]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if not (self.org_id and self.project_id):
|
if not (self.org_id and self.project_id):
|
||||||
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
||||||
@@ -1103,7 +1116,11 @@ class AsyncMemoryClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
payload = self._prepare_params(
|
payload = self._prepare_params(
|
||||||
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
|
{
|
||||||
|
"custom_instructions": custom_instructions,
|
||||||
|
"custom_categories": custom_categories,
|
||||||
|
"retrieval_criteria": retrieval_criteria,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
response = await self.async_client.patch(
|
response = await self.async_client.patch(
|
||||||
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
|
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
|
||||||
@@ -1113,7 +1130,12 @@ class AsyncMemoryClient:
|
|||||||
capture_client_event(
|
capture_client_event(
|
||||||
"client.update_project",
|
"client.update_project",
|
||||||
self,
|
self,
|
||||||
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"},
|
{
|
||||||
|
"custom_instructions": custom_instructions,
|
||||||
|
"custom_categories": custom_categories,
|
||||||
|
"retrieval_criteria": retrieval_criteria,
|
||||||
|
"sync_type": "async",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@@ -1174,4 +1196,3 @@ class AsyncMemoryClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
capture_client_event("client.feedback", self, data, {"sync_type": "async"})
|
capture_client_event("client.feedback", self, data, {"sync_type": "async"})
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Union, Type
+from typing import Any, Dict, Optional, Type, Union

 from pydantic import BaseModel, Field, model_validator

@@ -7,33 +7,17 @@ class OpenSearchConfig(BaseModel):
     collection_name: str = Field("mem0", description="Name of the index")
     host: str = Field("localhost", description="OpenSearch host")
     port: int = Field(9200, description="OpenSearch port")
-    user: Optional[str] = Field(
-        None, description="Username for authentication"
-    )
-    password: Optional[str] = Field(
-        None, description="Password for authentication"
-    )
-    api_key: Optional[str] = Field(
-        None, description="API key for authentication (if applicable)"
-    )
-    embedding_model_dims: int = Field(
-        1536, description="Dimension of the embedding vector"
-    )
-    verify_certs: bool = Field(
-        False, description="Verify SSL certificates (default False for OpenSearch)"
-    )
-    use_ssl: bool = Field(
-        False, description="Use SSL for connection (default False for OpenSearch)"
-    )
-    http_auth: Optional[object] = Field(
-        None, description="HTTP authentication method / AWS SigV4"
-    )
+    user: Optional[str] = Field(None, description="Username for authentication")
+    password: Optional[str] = Field(None, description="Password for authentication")
+    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
+    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
+    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
     connection_class: Optional[Union[str, Type]] = Field(
         "RequestsHttpConnection", description="Connection class for OpenSearch"
     )
-    pool_maxsize: int = Field(
-        20, description="Maximum number of connections in the pool"
-    )
+    pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")

     @model_validator(mode="before")
     @classmethod
@@ -41,7 +25,7 @@ class OpenSearchConfig(BaseModel):
         # Check if host is provided
         if not values.get("host"):
             raise ValueError("Host must be provided for OpenSearch")

         return values

     @model_validator(mode="before")
@@ -52,7 +36,6 @@ class OpenSearchConfig(BaseModel):
         extra_fields = input_fields - allowed_fields
         if extra_fields:
             raise ValueError(
-                f"Extra fields not allowed: {', '.join(extra_fields)}. "
-                f"Allowed fields: {', '.join(allowed_fields)}"
+                f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
             )
         return values

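Plugged into a mem0 config, the fields above map onto a vector-store block along these lines. The "opensearch" provider name and a locally running OpenSearch instance are assumptions here, and the credentials are placeholders; the field names and defaults mirror the Field(...) declarations in the diff.

from mem0 import Memory

config = {
    "vector_store": {
        "provider": "opensearch",
        "config": {
            "collection_name": "mem0",
            "host": "localhost",
            "port": 9200,
            "user": "admin",  # optional basic-auth credentials
            "password": "admin",
            "use_ssl": False,  # defaults mirror the declarations above
            "verify_certs": False,
            "embedding_model_dims": 1536,
        },
    }
}
memory = Memory.from_config(config)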
@@ -23,12 +23,12 @@ class AWSBedrockEmbedding(EmbeddingBase):
         super().__init__(config)

         self.config.model = self.config.model or "amazon.titan-embed-text-v1"

         # Get AWS config from environment variables or use defaults
         aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
         aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
         aws_region = os.environ.get("AWS_REGION", "us-west-2")

         # Check if AWS config is provided in the config
         if hasattr(self.config, "aws_access_key_id"):
             aws_access_key = self.config.aws_access_key_id
@@ -36,7 +36,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
             aws_secret_key = self.config.aws_secret_access_key
         if hasattr(self.config, "aws_region"):
             aws_region = self.config.aws_region

         self.client = boto3.client(
             "bedrock-runtime",
             region_name=aws_region,

@@ -11,6 +11,7 @@ logging.getLogger("transformers").setLevel(logging.WARNING)
 logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
 logging.getLogger("huggingface_hub").setLevel(logging.WARNING)


 class HuggingFaceEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)

@@ -22,7 +22,8 @@ class Neo4jConfig(BaseModel):
         if not url or not username or not password:
             raise ValueError("Please provide 'url', 'username' and 'password'.")
         return values


 class MemgraphConfig(BaseModel):
     url: Optional[str] = Field(None, description="Host address for the graph database")
     username: Optional[str] = Field(None, description="Username for the graph database")

@@ -20,18 +20,19 @@ def extract_provider(model: str) -> str:
|
|||||||
return provider
|
return provider
|
||||||
raise ValueError(f"Unknown provider in model: {model}")
|
raise ValueError(f"Unknown provider in model: {model}")
|
||||||
|
|
||||||
|
|
||||||
class AWSBedrockLLM(LLMBase):
|
class AWSBedrockLLM(LLMBase):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
if not self.config.model:
|
if not self.config.model:
|
||||||
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||||
|
|
||||||
# Get AWS config from environment variables or use defaults
|
# Get AWS config from environment variables or use defaults
|
||||||
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
|
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
|
||||||
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
|
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
|
||||||
aws_region = os.environ.get("AWS_REGION", "us-west-2")
|
aws_region = os.environ.get("AWS_REGION", "us-west-2")
|
||||||
|
|
||||||
# Check if AWS config is provided in the config
|
# Check if AWS config is provided in the config
|
||||||
if hasattr(self.config, "aws_access_key_id"):
|
if hasattr(self.config, "aws_access_key_id"):
|
||||||
aws_access_key = self.config.aws_access_key_id
|
aws_access_key = self.config.aws_access_key_id
|
||||||
@@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase):
|
|||||||
aws_secret_key = self.config.aws_secret_access_key
|
aws_secret_key = self.config.aws_secret_access_key
|
||||||
if hasattr(self.config, "aws_region"):
|
if hasattr(self.config, "aws_region"):
|
||||||
aws_region = self.config.aws_region
|
aws_region = self.config.aws_region
|
||||||
|
|
||||||
self.client = boto3.client(
|
self.client = boto3.client(
|
||||||
"bedrock-runtime",
|
"bedrock-runtime",
|
||||||
region_name=aws_region,
|
region_name=aws_region,
|
||||||
aws_access_key_id=aws_access_key if aws_access_key else None,
|
aws_access_key_id=aws_access_key if aws_access_key else None,
|
||||||
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
|
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_kwargs = {
|
self.model_kwargs = {
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
"max_tokens_to_sample": self.config.max_tokens,
|
"max_tokens_to_sample": self.config.max_tokens,
|
||||||
@@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase):
             input_body = {
                 "inputText": prompt,
                 "textGenerationConfig": {
-                    "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                    "maxTokenCount": self.model_kwargs["max_tokens_to_sample"]
+                    or self.model_kwargs["max_tokens"]
+                    or 5000,
                     "topP": self.model_kwargs["top_p"] or 0.9,
                     "temperature": self.model_kwargs["temperature"] or 0.1,
                 },
@@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase):
         body = json.dumps(input_body)

         if provider == "anthropic" or provider == "deepseek":

             input_body = {
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "text": prompt}]
-                    }
-                ],
+                "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                 "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
                 "temperature": self.model_kwargs["temperature"] or 0.1,
                 "top_p": self.model_kwargs["top_p"] or 0.9,
                 "anthropic_version": "bedrock-2023-05-31",
             }

             body = json.dumps(input_body)


         response = self.client.invoke_model(
             body=body,
@@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase):
             modelId=self.config.model,
             accept="application/json",
             contentType="application/json",
         )

         return self._parse_response(response, tools)

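For orientation, the Anthropic-style request body assembled above corresponds to a raw bedrock-runtime call like the sketch below. It assumes boto3 credentials are configured; the model id is the class default that appears earlier in this diff, and the prompt text is illustrative.

import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-west-2")
input_body = {
    "messages": [{"role": "user", "content": [{"type": "text", "text": "Summarize what Mem0 does in one line."}]}],
    "max_tokens": 5000,
    "temperature": 0.1,
    "top_p": 0.9,
    "anthropic_version": "bedrock-2023-05-31",
}
response = client.invoke_model(
    body=json.dumps(input_body),
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
    accept="application/json",
    contentType="application/json",
)
# The response body is a streaming payload; the generated text sits under "content".
print(json.loads(response["body"].read())["content"][0]["text"])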
@@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory


 def _build_filters_and_metadata(
     *,  # Enforce keyword-only arguments
     user_id: Optional[str] = None,
     agent_id: Optional[str] = None,
     run_id: Optional[str] = None,
     actor_id: Optional[str] = None,  # For query-time filtering
     input_metadata: Optional[Dict[str, Any]] = None,
     input_filters: Optional[Dict[str, Any]] = None,
 ) -> tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Constructs metadata for storage and filters for querying based on session and actor identifiers.

     This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:

@@ -78,10 +78,10 @@ def _build_filters_and_metadata(
     - effective_query_filters (Dict[str, Any]): Filters for querying memories,
       scoped to the determined session and potentially a resolved actor.
     """

     base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
     effective_query_filters = deepcopy(input_filters) if input_filters else {}

     # ---------- resolve session id (mandatory) ----------
     session_key, session_val = None, None
     if user_id:
@@ -90,20 +90,20 @@ def _build_filters_and_metadata(
         session_key, session_val = "agent_id", agent_id
     elif run_id:
         session_key, session_val = "run_id", run_id

     if session_key is None:
         raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")

     base_metadata_template[session_key] = session_val
     effective_query_filters[session_key] = session_val

     # ---------- optional actor filter ----------
     resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
     if resolved_actor_id:
         effective_query_filters["actor_id"] = resolved_actor_id

     return base_metadata_template, effective_query_filters


 setup_config()
 logger = logging.getLogger(__name__)

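Concretely, the helper above behaves as in the short sketch below for a typical call. The values are placeholders; the comments simply mirror the session-id and actor resolution visible in the function body.

metadata, filters = _build_filters_and_metadata(
    user_id="alice",
    actor_id="assistant",
    input_metadata={"category": "movie"},
)
# metadata -> {"category": "movie", "user_id": "alice"}      (attached to every stored memory)
# filters  -> {"user_id": "alice", "actor_id": "assistant"}  (applied when querying)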
@@ -189,7 +189,7 @@ class Memory(MemoryBase):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new memory.
|
Create a new memory.
|
||||||
|
|
||||||
Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.
|
Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -208,7 +208,7 @@ class Memory(MemoryBase):
|
|||||||
creating procedural memories (typically requires 'agent_id'). Otherwise, memories
|
creating procedural memories (typically requires 'agent_id'). Otherwise, memories
|
||||||
are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
|
are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
|
||||||
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
|
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the result of the memory addition operation, typically
|
dict: A dictionary containing the result of the memory addition operation, typically
|
||||||
@@ -216,14 +216,14 @@ class Memory(MemoryBase):
|
|||||||
and potentially "relations" if graph store is enabled.
|
and potentially "relations" if graph store is enabled.
|
||||||
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
|
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
processed_metadata, effective_filters = _build_filters_and_metadata(
|
processed_metadata, effective_filters = _build_filters_and_metadata(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
input_metadata=metadata,
|
input_metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
|
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
|
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
|
||||||
@@ -231,10 +231,10 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
elif isinstance(messages, dict):
|
elif isinstance(messages, dict):
|
||||||
messages = [messages]
|
messages = [messages]
|
||||||
|
|
||||||
elif not isinstance(messages, list):
|
elif not isinstance(messages, list):
|
||||||
raise ValueError("messages must be str, dict, or list[dict]")
|
raise ValueError("messages must be str, dict, or list[dict]")
|
||||||
|
|
||||||
@@ -255,7 +255,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
vector_store_result = future1.result()
|
vector_store_result = future1.result()
|
||||||
graph_result = future2.result()
|
graph_result = future2.result()
|
||||||
|
|
||||||
if self.api_version == "v1.0":
|
if self.api_version == "v1.0":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The current add API output format is deprecated. "
|
"The current add API output format is deprecated. "
|
||||||
@@ -277,21 +277,21 @@
     def _add_to_vector_store(self, messages, metadata, filters, infer):
         if not infer:
             returned_memories = []
             for message_dict in messages:
-                if not isinstance(message_dict, dict) or \
-                    message_dict.get("role") is None or \
-                    message_dict.get("content") is None:
+                if (
+                    not isinstance(message_dict, dict)
+                    or message_dict.get("role") is None
+                    or message_dict.get("content") is None
+                ):
                     logger.warning(f"Skipping invalid message format: {message_dict}")
                     continue

                 if message_dict["role"] == "system":
                     continue

                 per_msg_meta = deepcopy(metadata)
                 per_msg_meta["role"] = message_dict["role"]

                 actor_name = message_dict.get("name")
                 if actor_name:
                     per_msg_meta["actor_id"] = actor_name
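That raw-storage branch (infer=False) is what a call like the following sketch exercises: each non-system message is stored verbatim instead of being distilled into facts. It assumes the public Memory.add exposes the same infer flag as the OSS API; the ids and contents are placeholders.

memory.add(
    [
        {"role": "user", "content": "Remember this exactly as written.", "name": "alice"},
        {"role": "assistant", "content": "Stored verbatim."},
    ],
    user_id="alice",
    infer=False,  # skip LLM fact extraction; the "name" field becomes the per-message actor_id
)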
@@ -311,8 +311,8 @@ class Memory(MemoryBase):
|
|||||||
)
|
)
|
||||||
return returned_memories
|
return returned_memories
|
||||||
|
|
||||||
parsed_messages = parse_messages(messages)
|
parsed_messages = parse_messages(messages)
|
||||||
|
|
||||||
if self.config.custom_fact_extraction_prompt:
|
if self.config.custom_fact_extraction_prompt:
|
||||||
system_prompt = self.config.custom_fact_extraction_prompt
|
system_prompt = self.config.custom_fact_extraction_prompt
|
||||||
user_prompt = f"Input:\n{parsed_messages}"
|
user_prompt = f"Input:\n{parsed_messages}"
|
||||||
@@ -336,7 +336,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
retrieved_old_memory = []
|
retrieved_old_memory = []
|
||||||
new_message_embeddings = {}
|
new_message_embeddings = {}
|
||||||
for new_mem in new_retrieved_facts:
|
for new_mem in new_retrieved_facts:
|
||||||
messages_embeddings = self.embedding_model.embed(new_mem, "add")
|
messages_embeddings = self.embedding_model.embed(new_mem, "add")
|
||||||
new_message_embeddings[new_mem] = messages_embeddings
|
new_message_embeddings[new_mem] = messages_embeddings
|
||||||
existing_memories = self.vector_store.search(
|
existing_memories = self.vector_store.search(
|
||||||
@@ -347,7 +347,7 @@ class Memory(MemoryBase):
|
|||||||
)
|
)
|
||||||
for mem in existing_memories:
|
for mem in existing_memories:
|
||||||
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
|
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
|
||||||
|
|
||||||
unique_data = {}
|
unique_data = {}
|
||||||
for item in retrieved_old_memory:
|
for item in retrieved_old_memory:
|
||||||
unique_data[item["id"]] = item
|
unique_data[item["id"]] = item
|
||||||
@@ -389,7 +389,7 @@ class Memory(MemoryBase):
|
|||||||
if not action_text:
|
if not action_text:
|
||||||
logging.info("Skipping memory entry because of empty `text` field.")
|
logging.info("Skipping memory entry because of empty `text` field.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_type = resp.get("event")
|
event_type = resp.get("event")
|
||||||
if event_type == "ADD":
|
if event_type == "ADD":
|
||||||
memory_id = self._create_memory(
|
memory_id = self._create_memory(
|
||||||
@@ -405,16 +405,23 @@
                         existing_embeddings=new_message_embeddings,
                         metadata=deepcopy(metadata),
                     )
-                    returned_memories.append({
-                        "id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
-                        "event": event_type, "previous_memory": resp.get("old_memory"),
-                    })
+                    returned_memories.append(
+                        {
+                            "id": temp_uuid_mapping[resp.get("id")],
+                            "memory": action_text,
+                            "event": event_type,
+                            "previous_memory": resp.get("old_memory"),
+                        }
+                    )
                 elif event_type == "DELETE":
                     self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
-                    returned_memories.append({
-                        "id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
-                        "event": event_type,
-                    })
+                    returned_memories.append(
+                        {
+                            "id": temp_uuid_mapping[resp.get("id")],
+                            "memory": action_text,
+                            "event": event_type,
+                        }
+                    )
                 elif event_type == "NONE":
                     logging.info("NOOP for Memory.")
             except Exception as e:
@@ -462,11 +469,8 @@
             "actor_id",
             "role",
         ]

-        core_and_promoted_keys = {
-            "data", "hash", "created_at", "updated_at", "id",
-            *promoted_payload_keys
-        }
+        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

         result_item = MemoryItem(
             id=memory.id,
@@ -479,18 +483,16 @@ class Memory(MemoryBase):
|
|||||||
for key in promoted_payload_keys:
|
for key in promoted_payload_keys:
|
||||||
if key in memory.payload:
|
if key in memory.payload:
|
||||||
result_item[key] = memory.payload[key]
|
result_item[key] = memory.payload[key]
|
||||||
|
|
||||||
additional_metadata = {
|
additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
|
||||||
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
|
|
||||||
}
|
|
||||||
if additional_metadata:
|
if additional_metadata:
|
||||||
result_item["metadata"] = additional_metadata
|
result_item["metadata"] = additional_metadata
|
||||||
|
|
||||||
return result_item
|
return result_item
|
||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
agent_id: Optional[str] = None,
|
agent_id: Optional[str] = None,
|
||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
@@ -505,7 +507,7 @@ class Memory(MemoryBase):
|
|||||||
agent_id (str, optional): agent id
|
agent_id (str, optional): agent id
|
||||||
run_id (str, optional): run id
|
run_id (str, optional): run id
|
||||||
filters (dict, optional): Additional custom key-value filters to apply to the search.
|
filters (dict, optional): Additional custom key-value filters to apply to the search.
|
||||||
These are merged with the ID-based scoping filters. For example,
|
These are merged with the ID-based scoping filters. For example,
|
||||||
`filters={"actor_id": "some_user"}`.
|
`filters={"actor_id": "some_user"}`.
|
||||||
limit (int, optional): The maximum number of memories to return. Defaults to 100.
|
limit (int, optional): The maximum number of memories to return. Defaults to 100.
|
||||||
|
|
||||||
@@ -515,21 +517,16 @@
             it might return a direct list (see deprecation warning).
             Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
         """

         _, effective_filters = _build_filters_and_metadata(
-            user_id=user_id,
-            agent_id=agent_id,
-            run_id=run_id,
-            input_filters=filters
+            user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
         )

         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")

         capture_event(
-            "mem0.get_all",
-            self,
-            {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
+            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
         )

         with concurrent.futures.ThreadPoolExecutor() as executor:
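In application code, the reflowed get_all call is typically driven as in this sketch. The ids and the actor filter are placeholders, and `config` stands for any valid mem0 config dict; the filters/limit semantics follow the docstring shown in this hunk.

memory = Memory.from_config(config)  # any valid mem0 config dict
page = memory.get_all(
    user_id="alice",
    filters={"actor_id": "assistant"},  # optional extra filter, merged with the id-based scoping
    limit=50,
)
for item in page["results"]:
    print(item["id"], item["memory"])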
@@ -542,9 +539,9 @@ class Memory(MemoryBase):
                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
            )

            all_memories_result = future_memories.result()
            graph_entities_result = future_graph_entities.result() if future_graph_entities else None

        if self.enable_graph:
            return {"results": all_memories_result, "relations": graph_entities_result}

@@ -556,26 +553,27 @@ class Memory(MemoryBase):
                category=DeprecationWarning,
                stacklevel=2,
            )
            return all_memories_result
        else:
            return {"results": all_memories_result}

    def _get_all_from_vector_store(self, filters, limit):
        memories_result = self.vector_store.list(filters=filters, limit=limit)
        actual_memories = (
            memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
        )

        promoted_payload_keys = [
            "user_id",
            "agent_id",
            "run_id",
            "actor_id",
            "role",
        ]
        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

        formatted_memories = []
        for mem in actual_memories:
            memory_item_dict = MemoryItem(
                id=mem.id,
                memory=mem.payload["data"],
@@ -587,15 +585,13 @@ class Memory(MemoryBase):
            for key in promoted_payload_keys:
                if key in mem.payload:
                    memory_item_dict[key] = mem.payload[key]

            additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
            if additional_metadata:
                memory_item_dict["metadata"] = additional_metadata

            formatted_memories.append(memory_item_dict)

        return formatted_memories

    def search(
@@ -624,12 +620,9 @@ class Memory(MemoryBase):
                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
        """
        _, effective_filters = _build_filters_and_metadata(
            user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
        )

        if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")

@@ -651,7 +644,7 @@ class Memory(MemoryBase):

        original_memories = future_memories.result()
        graph_entities = future_graph_entities.result() if future_graph_entities else None

        if self.enable_graph:
            return {"results": original_memories, "relations": graph_entities}

@@ -678,11 +671,8 @@ class Memory(MemoryBase):
            "actor_id",
            "role",
        ]

        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

        original_memories = []
        for mem in memories:
@@ -693,18 +683,16 @@ class Memory(MemoryBase):
                created_at=mem.payload.get("created_at"),
                updated_at=mem.payload.get("updated_at"),
                score=mem.score,
            ).model_dump()

            for key in promoted_payload_keys:
                if key in mem.payload:
                    memory_item_dict[key] = mem.payload[key]

            additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
            if additional_metadata:
                memory_item_dict["metadata"] = additional_metadata

            original_memories.append(memory_item_dict)

        return original_memories
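For orientation, the reformatted `get_all`/`search` API above can be exercised roughly as in the sketch below. The default `Memory()` construction and the example IDs are illustrative assumptions; the required scoping IDs, the extra `filters` merging, and the `{"results": [...]}` return shape come from the docstrings in this diff.

```python
from mem0 import Memory

m = Memory()  # default configuration assumed for illustration

# get_all requires at least one of user_id / agent_id / run_id;
# extra filters (e.g. actor_id) are merged with that scoping.
all_memories = m.get_all(user_id="alice", filters={"actor_id": "some_user"}, limit=50)
for item in all_memories["results"]:
    print(item["id"], item["memory"])

# search returns scored results under the same "results" key.
hits = m.search("favorite food", user_id="alice", limit=5)
for item in hits["results"]:
    print(item["score"], item["memory"])
```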
@@ -738,7 +726,7 @@ class Memory(MemoryBase):
        self._delete_memory(memory_id)
        return {"message": "Memory deleted successfully!"}

    def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
        """
        Delete all memories.

@@ -860,11 +848,11 @@ class Memory(MemoryBase):
        except Exception:
            logger.error(f"Error getting memory with ID {memory_id} during update.")
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")

        prev_value = existing_memory.payload.get("data")

        new_metadata = deepcopy(metadata) if metadata is not None else {}

        new_metadata["data"] = data
        new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -875,7 +863,7 @@ class Memory(MemoryBase):
        if "agent_id" in existing_memory.payload:
            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
        if "run_id" in existing_memory.payload:
            new_metadata["run_id"] = existing_memory.payload["run_id"]
        if "actor_id" in existing_memory.payload:
            new_metadata["actor_id"] = existing_memory.payload["actor_id"]
        if "role" in existing_memory.payload:
@@ -885,14 +873,14 @@ class Memory(MemoryBase):
            embeddings = existing_embeddings[data]
        else:
            embeddings = self.embedding_model.embed(data, "update")

        self.vector_store.update(
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logger.info(f"Updating memory with ID {memory_id=} with {data=}")

        self.db.add_history(
            memory_id,
            prev_value,
@@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase):
            dict: A dictionary containing the result of the memory addition operation.
        """
        processed_metadata, effective_filters = _build_filters_and_metadata(
            user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
        )

        if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
            raise ValueError(
                f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -1050,15 +1035,17 @@ class AsyncMemory(MemoryBase):

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        elif isinstance(messages, dict):
            messages = [messages]

        elif not isinstance(messages, list):
            raise ValueError("messages must be str, dict, or list[dict]")

        if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
            results = await self._create_procedural_memory(
                messages, metadata=processed_metadata, prompt=prompt, llm=llm
            )
            return results

        if self.config.llm.config.get("enable_vision"):
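As the normalization branch above shows, `add()` accepts a plain string, a single message dict, or a list of message dicts; a minimal sketch (the example contents are made up, and the helper mirrors the branch only for illustration):

```python
# Sketch of the three shapes accepted for `messages`. str and dict inputs
# are normalized to a list of {"role": ..., "content": ...} dicts; anything
# else raises ValueError, exactly as in the branch above.
as_string = "I moved to Berlin last month."
as_dict = {"role": "user", "content": "I moved to Berlin last month."}
as_list = [
    {"role": "user", "content": "I moved to Berlin last month."},
    {"role": "assistant", "content": "Noted - you now live in Berlin."},
]


def normalize(messages):
    # Mirrors the normalization shown in this hunk.
    if isinstance(messages, str):
        return [{"role": "user", "content": messages}]
    if isinstance(messages, dict):
        return [messages]
    if not isinstance(messages, list):
        raise ValueError("messages must be str, dict, or list[dict]")
    return messages


assert normalize(as_string) == normalize(as_dict) == as_list[:1]
```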
@@ -1066,7 +1053,9 @@ class AsyncMemory(MemoryBase):
        else:
            messages = parse_vision_messages(messages)

        vector_store_task = asyncio.create_task(
            self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)
        )
        graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters))

        vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
@@ -1090,8 +1079,8 @@ class AsyncMemory(MemoryBase):
            return {"results": vector_store_result}

    async def _add_to_vector_store(
        self,
        messages: list,
        metadata: dict,
        filters: dict,
        infer: bool,
@@ -1099,9 +1088,11 @@ class AsyncMemory(MemoryBase):
        if not infer:
            returned_memories = []
            for message_dict in messages:
                if (
                    not isinstance(message_dict, dict)
                    or message_dict.get("role") is None
                    or message_dict.get("content") is None
                ):
                    logger.warning(f"Skipping invalid message format (async): {message_dict}")
                    continue

@@ -1110,20 +1101,24 @@ class AsyncMemory(MemoryBase):

                per_msg_meta = deepcopy(metadata)
                per_msg_meta["role"] = message_dict["role"]

                actor_name = message_dict.get("name")
                if actor_name:
                    per_msg_meta["actor_id"] = actor_name

                msg_content = message_dict["content"]
                msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
                mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta)

                returned_memories.append(
                    {
                        "id": mem_id,
                        "memory": msg_content,
                        "event": "ADD",
                        "actor_id": actor_name if actor_name else None,
                        "role": message_dict["role"],
                    }
                )
            return returned_memories

        parsed_messages = parse_messages(messages)
@@ -1142,17 +1137,21 @@ class AsyncMemory(MemoryBase):
            response = remove_code_blocks(response)
            new_retrieved_facts = json.loads(response)["facts"]
        except Exception as e:
            logging.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []

        retrieved_old_memory = []
        new_message_embeddings = {}

        async def process_fact_for_search(new_mem_content):
            embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add")
            new_message_embeddings[new_mem_content] = embeddings
            existing_mems = await asyncio.to_thread(
                self.vector_store.search,
                query=new_mem_content,
                vectors=embeddings,
                limit=5,
                filters=filters,  # 'filters' is query_filters_for_inference
            )
            return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]

@@ -1160,9 +1159,10 @@ class AsyncMemory(MemoryBase):
        search_results_list = await asyncio.gather(*search_tasks)
        for result_group in search_results_list:
            retrieved_old_memory.extend(result_group)

        unique_data = {}
        for item in retrieved_old_memory:
            unique_data[item["id"]] = item
        retrieved_old_memory = list(unique_data.values())
        logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
        temp_uuid_mapping = {}
@@ -1180,35 +1180,45 @@ class AsyncMemory(MemoryBase):
                response_format={"type": "json_object"},
            )
        except Exception as e:
            logging.error(f"Error in new memory actions response: {e}")
            response = ""

        try:
            response = remove_code_blocks(response)
            new_memories_with_actions = json.loads(response)
        except Exception as e:
            logging.error(f"Invalid JSON response: {e}")
            new_memories_with_actions = {}

        returned_memories = []
        try:
            memory_tasks = []
            for resp in new_memories_with_actions.get("memory", []):
                logging.info(resp)
                try:
                    action_text = resp.get("text")
                    if not action_text:
                        continue
                    event_type = resp.get("event")

                    if event_type == "ADD":
                        task = asyncio.create_task(
                            self._create_memory(
                                data=action_text,
                                existing_embeddings=new_message_embeddings,
                                metadata=deepcopy(metadata),
                            )
                        )
                        memory_tasks.append((task, resp, "ADD", None))
                    elif event_type == "UPDATE":
                        task = asyncio.create_task(
                            self._update_memory(
                                memory_id=temp_uuid_mapping[resp["id"]],
                                data=action_text,
                                existing_embeddings=new_message_embeddings,
                                metadata=deepcopy(metadata),
                            )
                        )
                        memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
                    elif event_type == "DELETE":
                        task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
@@ -1217,31 +1227,30 @@ class AsyncMemory(MemoryBase):
                        logging.info("NOOP for Memory (async).")
                except Exception as e:
                    logging.error(f"Error processing memory action (async): {resp}, Error: {e}")

            for task, resp, event_type, mem_id in memory_tasks:
                try:
                    result_id = await task
                    if event_type == "ADD":
                        returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type})
                    elif event_type == "UPDATE":
                        returned_memories.append(
                            {
                                "id": mem_id,
                                "memory": resp.get("text"),
                                "event": event_type,
                                "previous_memory": resp.get("old_memory"),
                            }
                        )
                    elif event_type == "DELETE":
                        returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type})
                except Exception as e:
                    logging.error(f"Error awaiting memory task (async): {e}")
        except Exception as e:
            logging.error(f"Error in memory processing loop (async): {e}")

        capture_event(
            "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
        )
        return returned_memories

@@ -1272,17 +1281,14 @@ class AsyncMemory(MemoryBase):
            return None

        promoted_payload_keys = [
            "user_id",
            "agent_id",
            "run_id",
            "actor_id",
            "role",
        ]

        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

        result_item = MemoryItem(
            id=memory.id,
@@ -1295,18 +1301,16 @@ class AsyncMemory(MemoryBase):
        for key in promoted_payload_keys:
            if key in memory.payload:
                result_item[key] = memory.payload[key]

        additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
        if additional_metadata:
            result_item["metadata"] = additional_metadata

        return result_item

    async def get_all(
        self,
        *,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
@@ -1314,41 +1318,36 @@ class AsyncMemory(MemoryBase):
        limit: int = 100,
    ):
        """
        List all memories.

        Args:
            user_id (str, optional): user id
            agent_id (str, optional): agent id
            run_id (str, optional): run id
            filters (dict, optional): Additional custom key-value filters to apply to the search.
                These are merged with the ID-based scoping filters. For example,
                `filters={"actor_id": "some_user"}`.
            limit (int, optional): The maximum number of memories to return. Defaults to 100.

        Returns:
            dict: A dictionary containing a list of memories under the "results" key,
                and potentially "relations" if graph store is enabled. For API v1.0,
                it might return a direct list (see deprecation warning).
                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
        """

        _, effective_filters = _build_filters_and_metadata(
            user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
        )

        if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError(
                "When 'conversation_id' is not provided (classic mode), "
                "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
            )

        capture_event(
            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
        )

        with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1361,9 +1360,9 @@ class AsyncMemory(MemoryBase):
                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
            )

            all_memories_result = future_memories.result()
            graph_entities_result = future_graph_entities.result() if future_graph_entities else None

        if self.enable_graph:
            return {"results": all_memories_result, "relations": graph_entities_result}

@@ -1381,20 +1380,21 @@ class AsyncMemory(MemoryBase):

    async def _get_all_from_vector_store(self, filters, limit):
        memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
        actual_memories = (
            memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
        )

        promoted_payload_keys = [
            "user_id",
            "agent_id",
            "run_id",
            "actor_id",
            "role",
        ]
        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

        formatted_memories = []
        for mem in actual_memories:
            memory_item_dict = MemoryItem(
                id=mem.id,
                memory=mem.payload["data"],
@@ -1406,15 +1406,13 @@ class AsyncMemory(MemoryBase):
            for key in promoted_payload_keys:
                if key in mem.payload:
                    memory_item_dict[key] = mem.payload[key]

            additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
            if additional_metadata:
                memory_item_dict["metadata"] = additional_metadata

            formatted_memories.append(memory_item_dict)

        return formatted_memories

    async def search(
@@ -1442,16 +1440,13 @@ class AsyncMemory(MemoryBase):
                and potentially "relations" if graph store is enabled.
                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
        """

        _, effective_filters = _build_filters_and_metadata(
            user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
        )

        if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")

        capture_event(
            "mem0.search",
@@ -1460,22 +1455,20 @@ class AsyncMemory(MemoryBase):
        )

        vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))

        graph_task = None
        if self.enable_graph:
            if hasattr(self.graph.search, "__await__"):  # Check if graph search is async
                graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit))
            else:
                graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit))

        if graph_task:
            original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
        else:
            original_memories = await vector_store_task
            graph_entities = None

        if self.enable_graph:
            return {"results": original_memories, "relations": graph_entities}

@@ -1504,11 +1497,8 @@ class AsyncMemory(MemoryBase):
            "actor_id",
            "role",
        ]

        core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}

        original_memories = []
        for mem in memories:
@@ -1518,19 +1508,17 @@ class AsyncMemory(MemoryBase):
                hash=mem.payload.get("hash"),
                created_at=mem.payload.get("created_at"),
                updated_at=mem.payload.get("updated_at"),
                score=mem.score,
            ).model_dump()

            for key in promoted_payload_keys:
                if key in mem.payload:
                    memory_item_dict[key] = mem.payload[key]

            additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
            if additional_metadata:
                memory_item_dict["metadata"] = additional_metadata

            original_memories.append(memory_item_dict)

        return original_memories
@@ -1650,7 +1638,7 @@ class AsyncMemory(MemoryBase):
        capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"})
        return memory_id

    async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
        """
        Create a procedural memory asynchronously

@@ -1709,11 +1697,11 @@ class AsyncMemory(MemoryBase):
        except Exception:
            logger.error(f"Error getting memory with ID {memory_id} during update.")
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")

        prev_value = existing_memory.payload.get("data")

        new_metadata = deepcopy(metadata) if metadata is not None else {}

        new_metadata["data"] = data
        new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -1725,8 +1713,7 @@ class AsyncMemory(MemoryBase):
            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
        if "run_id" in existing_memory.payload:
            new_metadata["run_id"] = existing_memory.payload["run_id"]

        if "actor_id" in existing_memory.payload:
            new_metadata["actor_id"] = existing_memory.payload["actor_id"]
        if "role" in existing_memory.payload:
@@ -1736,7 +1723,7 @@ class AsyncMemory(MemoryBase):
            embeddings = existing_embeddings[data]
        else:
            embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")

        await asyncio.to_thread(
            self.vector_store.update,
            vector_id=memory_id,
@@ -1744,7 +1731,7 @@ class AsyncMemory(MemoryBase):
            payload=new_metadata,
        )
        logger.info(f"Updating memory with ID {memory_id=} with {data=}")

        await asyncio.to_thread(
            self.db.add_history,
            memory_id,
@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try:
    from langchain_memgraph import Memgraph
except ImportError:
    raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")

from mem0.graphs.tools import (
    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
            filters (dict): A dictionary containing filters to be applied during the addition.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        # TODO: Batch queries with APOC plugin
        # TODO: Add more filter support
        deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
        added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

@@ -108,16 +96,13 @@ class MemoryGraph:
            - "entities": List of related graph data based on the query.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)

@@ -126,9 +111,7 @@ class MemoryGraph:

        search_results = []
        for item in reranked_results:
            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})

        logger.info(f"Returned {len(search_results)} search results")

@@ -161,9 +144,7 @@ class MemoryGraph:
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        LIMIT $limit
        """
        results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})

        final_results = []
        for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
        return entity_type_map

    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
                        "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
                    ),
                },
@@ -235,9 +209,7 @@ class MemoryGraph:
            messages = [
                {
                    "role": "system",
                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
                },
                {
                    "role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """Get the entities to be deleted from the search output."""
        search_output_string = format_entities(search_output)
        system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
            # search for the nodes with the closest embeddings; this is basically
            # comparison of one embedding to all embeddings in a graph -> vector
            # search with cosine similarity metric
            source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
            destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)

            # TODO: Create a cypher query and common params for all the cases
            if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
                """

                params = {
                    "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                    "source_name": source,
                    "source_embedding": source_embedding,
                    "user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
                """
                params = {
                    "source_id": source_node_search_result[0]["id(source_candidate)"],
                    "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                    "user_id": user_id,
                }
            else:
@@ -1,8 +1,8 @@
import logging
import sqlite3
import threading
import uuid
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

@@ -23,9 +23,7 @@ class SQLiteManager:
        """
        with self._lock, self.connection:
            cur = self.connection.cursor()
            cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
            if cur.fetchone() is None:
                return  # nothing to migrate

@@ -51,13 +49,11 @@ class SQLiteManager:
            logger.info("Migrating history table to new schema (no convo columns).")
            cur.execute("ALTER TABLE history RENAME TO history_old")

            self._create_history_table()

            intersecting = list(expected_cols & old_cols)
            cols_csv = ", ".join(intersecting)
            cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
            cur.execute("DROP TABLE history_old")

    def _create_history_table(self) -> None:
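The migration above copies only the columns shared by the old and new schemas. A self-contained sketch of the same rename/recreate/copy/drop pattern, with toy column names standing in for the real history schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# Old schema carries an extra column that the new schema drops.
cur.execute("CREATE TABLE history (id TEXT, memory_id TEXT, convo_id TEXT)")
cur.execute("INSERT INTO history VALUES ('1', 'm1', 'c1')")

# Rename, recreate, copy the intersection of old and new columns, drop the old table.
cur.execute("ALTER TABLE history RENAME TO history_old")
cur.execute("CREATE TABLE history (id TEXT, memory_id TEXT, new_value TEXT)")

expected_cols = {"id", "memory_id", "new_value"}
old_cols = {"id", "memory_id", "convo_id"}
cols_csv = ", ".join(expected_cols & old_cols)
cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
cur.execute("DROP TABLE history_old")
conn.commit()
```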
@@ -9,8 +9,8 @@ import mem0
from mem0.memory.setup import get_or_create_user_id

MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"

if isinstance(MEM0_TELEMETRY, str):
    MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")

@@ -98,9 +98,8 @@ class VectorStoreFactory:
            return vector_store_instance(**config)
        else:
            raise ValueError(f"Unsupported VectorStore provider: {provider_name}")

    @classmethod
    def reset(cls, instance):
        instance.reset()
        return instance

@@ -377,4 +377,3 @@ class AzureAISearch(VectorStoreBase):
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise

@@ -51,7 +51,7 @@ class VectorStoreBase(ABC):
    def list(self, filters=None, limit=None):
        """List all memories."""
        pass

    @abstractmethod
    def reset(self):
        """Reset by delete the collection and recreate it."""

@@ -221,7 +221,7 @@ class ChromaDB(VectorStoreBase):
        """
        results = self.collection.get(where=filters, limit=limit)
        return [self._parse_output(results)]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
@@ -58,7 +58,12 @@ class ElasticsearchDB(VectorStoreBase):
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": self.embedding_model_dims,
                        "index": True,
                        "similarity": "cosine",
                    },
                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                }
            },
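For reference, the reformatted mapping corresponds to an index roughly like the sketch below. The index name, the fixed dimension standing in for `self.embedding_model_dims`, and the direct `elasticsearch` 8.x client call are illustrative assumptions, not part of this diff.

```python
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # connection details are illustrative

# Same field layout as the mapping above: a cosine-similarity dense_vector
# plus a keyword-typed user_id inside the metadata object.
es.indices.create(
    index="mem0",
    mappings={
        "properties": {
            "text": {"type": "text"},
            "vector": {"type": "dense_vector", "dims": 1536, "index": True, "similarity": "cosine"},
            "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
        }
    },
)
```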
@@ -222,7 +227,7 @@ class ElasticsearchDB(VectorStoreBase):
        )

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -465,7 +465,7 @@ class FAISS(VectorStoreBase):
                break

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")

@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance

@@ -162,10 +163,7 @@ class Langchain(VectorStoreBase):
        if filters and "user_id" in filters:
            where_clause = {"user_id": filters["user_id"]}

        result = self.client._collection.get(where=where_clause, limit=limit)

        # Convert the result to the expected format
        if result and isinstance(result, dict):
|
|||||||
@@ -237,7 +237,7 @@ class MilvusDB(VectorStoreBase):
|
|||||||
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
|
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
|
||||||
memories.append(obj)
|
memories.append(obj)
|
||||||
return [memories]
|
return [memories]
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the index by deleting and recreating it."""
|
"""Reset the index by deleting and recreating it."""
|
||||||
logger.warning(f"Resetting index {self.collection_name}...")
|
logger.warning(f"Resetting index {self.collection_name}...")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
import logging
-from typing import Any, Dict, List, Optional
import time
+from typing import Any, Dict, List, Optional

try:
from opensearchpy import OpenSearch, RequestsHttpConnection

@@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
-pool_maxsize=20
+pool_maxsize=20,
)

self.collection_name = config.collection_name

@@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
-"settings": {
-"index.knn": True
-},
+"settings": {"index.knn": True},
"mappings": {
"properties": {
"vector_field": {

@@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase):
"payload": {"type": "object"},
"id": {"type": "keyword"},
}
-}
+},
}

if not self.client.indices.exists(index=name):
@@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
retry_count += 1
if retry_count == max_retries:
-raise TimeoutError(
-f"Index {name} creation timed out after {max_retries} seconds"
-)
+raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
time.sleep(0.5)

def insert(

@@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase):
}

# Start building the full query
-query_body = {
-"size": limit * 2,
-"query": None
-}
+query_body = {"size": limit * 2, "query": None}

# Prepare filter conditions if applicable
filter_clauses = []

@@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase):
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
-filter_clauses.append({
-"term": {f"payload.{key}.keyword": value}
-})
+filter_clauses.append({"term": {f"payload.{key}.keyword": value}})

# Combine knn with filters if needed
if filter_clauses:
-query_body["query"] = {
-"bool": {
-"must": knn_query,
-"filter": filter_clauses
-}
-}
+query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
else:
query_body["query"] = knn_query
@@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase):

hits = response["hits"]["hits"]
results = [
-OutputData(
-id=hit["_source"].get("id"),
-score=hit["_score"],
-payload=hit["_source"].get("payload", {})
-)
+OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
for hit in hits
]
return results

@@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase):
def delete(self, vector_id: str) -> None:
"""Delete a vector by custom ID."""
# First, find the document by custom ID
-search_query = {
-"query": {
-"term": {
-"id": vector_id
-}
-}
-}
+search_query = {"query": {"term": {"id": vector_id}}}

response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])

@@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase):
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)


def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload using the custom 'id' field."""

# First, find the document by custom ID
-search_query = {
-"query": {
-"term": {
-"id": vector_id
-}
-}
-}
+search_query = {"query": {"term": {"id": vector_id}}}

response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
@@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
pass


def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:

@@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase):
self.create_col(self.collection_name, self.embedding_model_dims)
return None

-search_query = {
-"query": {
-"term": {
-"id": vector_id
-}
-}
-}
+search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)

hits = response["hits"]["hits"]

@@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase):
if not hits:
return None

-return OutputData(
-id=hits[0]["_source"].get("id"),
-score=1.0,
-payload=hits[0]["_source"].get("payload", {})
-)
+return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
@@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase):
return self.client.indices.get(index=name)

def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:

try:
"""List all memories with optional filters."""
-query: Dict = {
-"query": {
-"match_all": {}
-}
-}
+query: Dict = {"query": {"match_all": {}}}

filter_clauses = []
if filters:
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
-filter_clauses.append({
-"term": {f"payload.{key}.keyword": value}
-})
+filter_clauses.append({"term": {f"payload.{key}.keyword": value}})

if filter_clauses:
-query["query"] = {
-"bool": {
-"filter": filter_clauses
-}
-}
+query["query"] = {"bool": {"filter": filter_clauses}}

if limit:
query["size"] = limit

@@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase):
response = self.client.search(index=self.collection_name, body=query)
hits = response["hits"]["hits"]

-return [[
-OutputData(
-id=hit["_source"].get("id"),
-score=1.0,
-payload=hit["_source"].get("payload", {})
-)
-for hit in hits
-]]
+return [
+[
+OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
+for hit in hits
+]
+]
except Exception:
return []


def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
@@ -286,7 +286,7 @@ class PGVector(VectorStoreBase):
self.cur.close()
if hasattr(self, "conn"):
self.conn.close()

def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")

@@ -232,7 +232,7 @@ class Qdrant(VectorStoreBase):
with_vectors=False,
)
return result

def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")

@@ -88,7 +88,7 @@ class RedisDB(VectorStoreBase):
The created index object.
"""
# Use provided parameters or fall back to instance attributes
-collection_name = name or self.schema['index']['name']
+collection_name = name or self.schema["index"]["name"]
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "cosine"

@@ -237,17 +237,16 @@ class RedisDB(VectorStoreBase):
"""
Reset the index by deleting and recreating it.
"""
-collection_name = self.schema['index']['name']
+collection_name = self.schema["index"]["name"]
logger.warning(f"Resetting index {collection_name}...")
self.delete_col()

self.index = SearchIndex.from_dict(self.schema)
self.index.set_client(self.client)
self.index.create(overwrite=True)

-#or use
-#self.create_col(collection_name, self.embedding_model_dims)

+# or use
+# self.create_col(collection_name, self.embedding_model_dims)

# Recreate the index with the same parameters
self.create_col(collection_name, self.embedding_model_dims)
@@ -229,7 +229,7 @@ class Supabase(VectorStoreBase):
records = self.collection.fetch(ids=ids)

return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]

def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")

@@ -285,10 +285,9 @@ class UpstashVector(VectorStoreBase):
- Per-namespace vector and pending vector counts
"""
return self.client.info()

def reset(self):
"""
Reset the Upstash Vector index.
"""
self.delete_col()

@@ -308,7 +308,7 @@ class Weaviate(VectorStoreBase):
payload["id"] = str(obj.uuid).split("'")[0]
results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
return [results]

def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
@@ -44,31 +44,14 @@ DEFAULT_CONFIG = {
"user": POSTGRES_USER,
"password": POSTGRES_PASSWORD,
"collection_name": POSTGRES_COLLECTION_NAME,
-}
+},
},
"graph_store": {
"provider": "neo4j",
-"config": {
-"url": NEO4J_URI,
-"username": NEO4J_USERNAME,
-"password": NEO4J_PASSWORD
-}
-},
-"llm": {
-"provider": "openai",
-"config": {
-"api_key": OPENAI_API_KEY,
-"temperature": 0.2,
-"model": "gpt-4o"
-}
-},
-"embedder": {
-"provider": "openai",
-"config": {
-"api_key": OPENAI_API_KEY,
-"model": "text-embedding-3-small"
-}
+"config": {"url": NEO4J_URI, "username": NEO4J_USERNAME, "password": NEO4J_PASSWORD},
},
+"llm": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "temperature": 0.2, "model": "gpt-4o"}},
+"embedder": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "model": "text-embedding-3-small"}},
"history_db_path": HISTORY_DB_PATH,
}
@@ -115,9 +98,7 @@ def set_config(config: Dict[str, Any]):
def add_memory(memory_create: MemoryCreate):
"""Store new memories."""
if not any([memory_create.user_id, memory_create.agent_id, memory_create.run_id]):
-raise HTTPException(
-status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required."
-)
+raise HTTPException(status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required.")

params = {k: v for k, v in memory_create.model_dump().items() if v is not None and k != "messages"}
try:

@@ -138,7 +119,9 @@ def get_all_memories(
if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.")
try:
-params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None}
+params = {
+k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
+}
return MEMORY_INSTANCE.get_all(**params)
except Exception as e:
logging.exception("Error in get_all_memories:")

@@ -207,7 +190,9 @@ def delete_all_memories(
if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.")
try:
-params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None}
+params = {
+k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
+}
MEMORY_INSTANCE.delete_all(**params)
return {"message": "All relevant memories deleted"}
except Exception as e:

@@ -229,4 +214,4 @@ def reset_memory():
@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False)
def home():
"""Redirect to the OpenAPI documentation."""
-return RedirectResponse(url='/docs')
+return RedirectResponse(url="/docs")
@@ -5,13 +5,15 @@ def test_get_update_memory_messages():
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
response_content = ["new fact"]
custom_update_memory_prompt = "custom prompt determining memory update"

## When custom update memory prompt is provided
##
-result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt)
+result = prompts.get_update_memory_messages(
+retrieved_old_memory_dict, response_content, custom_update_memory_prompt
+)
assert result.startswith(custom_update_memory_prompt)

## When custom update memory prompt is not provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)

@@ -10,9 +10,7 @@ from mem0.embeddings.lmstudio import LMStudioEmbedding
def mock_lm_studio_client():
with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai:
mock_client = Mock()
-mock_client.embeddings.create.return_value = Mock(
-data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])]
-)
+mock_client.embeddings.create.return_value = Mock(data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])])
mock_openai.return_value = mock_client
yield mock_client
@@ -23,7 +23,9 @@ def test_embed_default_model(mock_openai_client):

result = embedder.embed("Hello world")

-mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
+mock_openai_client.embeddings.create.assert_called_once_with(
+input=["Hello world"], model="text-embedding-3-small", dimensions=1536
+)
assert result == [0.1, 0.2, 0.3]


@@ -37,7 +39,7 @@ def test_embed_custom_model(mock_openai_client):
result = embedder.embed("Test embedding")

mock_openai_client.embeddings.create.assert_called_once_with(
-input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024
+input=["Test embedding"], model="text-embedding-2-medium", dimensions=1024
)
assert result == [0.4, 0.5, 0.6]

@@ -51,7 +53,9 @@ def test_embed_removes_newlines(mock_openai_client):

result = embedder.embed("Hello\nworld")

-mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
+mock_openai_client.embeddings.create.assert_called_once_with(
+input=["Hello world"], model="text-embedding-3-small", dimensions=1536
+)
assert result == [0.7, 0.8, 0.9]


@@ -65,7 +69,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
result = embedder.embed("Testing API key")

mock_openai_client.embeddings.create.assert_called_once_with(
-input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536
+input=["Testing API key"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.0, 1.1, 1.2]

@@ -81,6 +85,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
result = embedder.embed("Environment key test")

mock_openai_client.embeddings.create.assert_called_once_with(
-input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536
+input=["Environment key test"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.3, 1.4, 1.5]
@@ -24,11 +24,20 @@ def mock_config():
with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config:
mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json"
yield mock_config


@pytest.fixture
def mock_embedding_types():
-return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"]
+return [
+"SEMANTIC_SIMILARITY",
+"CLASSIFICATION",
+"CLUSTERING",
+"RETRIEVAL_DOCUMENT",
+"RETRIEVAL_QUERY",
+"QUESTION_ANSWERING",
+"FACT_VERIFICATION",
+"CODE_RETRIEVAL_QUERY",
+]


@pytest.fixture

@@ -79,30 +88,31 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
assert result == [0.4, 0.5, 0.6]


@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
-def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input):
+def test_embed_with_memory_action(
+mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input
+):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256

for embedding_type in mock_embedding_types:

mock_config.return_value.memory_add_embedding_type = embedding_type
mock_config.return_value.memory_update_embedding_type = embedding_type
mock_config.return_value.memory_search_embedding_type = embedding_type

config = mock_config()
embedder = VertexAIEmbedding(config)

mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004")

for memory_action in ["add", "update", "search"]:
embedder.embed("Hello world", memory_action=memory_action)

mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with(
texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256
)


@patch("mem0.embeddings.vertexai.os")
def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config):
@@ -137,15 +147,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi
result = embedder.embed("Large embedding test")

assert result == [0.1] * 1024


@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_invalid_memory_action(mock_text_embedding_model, mock_config):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256

config = mock_config()
embedder = VertexAIEmbedding(config)

with pytest.raises(ValueError):
embedder.embed("Hello world", memory_action="invalid_action")

@@ -127,4 +127,4 @@ def test_generate_with_http_proxies(default_headers):
api_version=None,
default_headers=default_headers,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")

@@ -31,12 +31,12 @@ def test_deepseek_llm_base_url():
# case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
-deepseek_base_url=config_base_url
+deepseek_base_url=config_base_url,
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url
@@ -99,16 +99,16 @@ def test_generate_response_with_tools(mock_deepseek_client):
response = llm.generate_response(messages, tools=tools)

mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
-tool_choice="auto"
+tool_choice="auto",
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

@@ -10,6 +10,7 @@ try:
from langchain.chat_models.base import BaseChatModel
except ImportError:
from unittest.mock import MagicMock

BaseChatModel = MagicMock


@@ -24,16 +25,11 @@ def mock_langchain_model():
def test_langchain_initialization(mock_langchain_model):
"""Test that LangchainLLM initializes correctly with a valid model."""
# Create a config with the model instance directly
-config = BaseLlmConfig(
-model=mock_langchain_model,
-temperature=0.7,
-max_tokens=100,
-api_key="test-api-key"
-)
+config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")

# Initialize the LangchainLLM
llm = LangchainLLM(config)

# Verify the model was correctly assigned
assert llm.langchain_model == mock_langchain_model
@@ -41,35 +37,30 @@ def test_langchain_initialization(mock_langchain_model):
def test_generate_response(mock_langchain_model):
"""Test that generate_response correctly processes messages and returns a response."""
# Create a config with the model instance
-config = BaseLlmConfig(
-model=mock_langchain_model,
-temperature=0.7,
-max_tokens=100,
-api_key="test-api-key"
-)
+config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")

# Initialize the LangchainLLM
llm = LangchainLLM(config)

# Create test messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
-{"role": "user", "content": "Tell me a joke."}
+{"role": "user", "content": "Tell me a joke."},
]

# Get response
response = llm.generate_response(messages)

# Verify the correct message format was passed to the model
expected_langchain_messages = [
("system", "You are a helpful assistant."),
("human", "Hello, how are you?"),
("ai", "I'm doing well! How can I help you?"),
-("human", "Tell me a joke.")
+("human", "Tell me a joke."),
]

mock_langchain_model.invoke.assert_called_once()
# Extract the first argument of the first call
actual_messages = mock_langchain_model.invoke.call_args[0][0]

@@ -79,25 +70,15 @@ def test_generate_response(mock_langchain_model):

def test_invalid_model():
"""Test that LangchainLLM raises an error with an invalid model."""
-config = BaseLlmConfig(
-model="not-a-valid-model-instance",
-temperature=0.7,
-max_tokens=100,
-api_key="test-api-key"
-)
+config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key")

with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
LangchainLLM(config)


def test_missing_model():
"""Test that LangchainLLM raises an error when model is None."""
-config = BaseLlmConfig(
-model=None,
-temperature=0.7,
-max_tokens=100,
-api_key="test-api-key"
-)
+config = BaseLlmConfig(model=None, temperature=0.7, max_tokens=100, api_key="test-api-key")

with pytest.raises(ValueError, match="`model` parameter is required"):
LangchainLLM(config)
@@ -11,9 +11,7 @@ def mock_lm_studio_client():
with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path
mock_client = Mock()
mock_client.chat.completions.create.return_value = Mock(
-choices=[
-Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-]
+choices=[Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
)
mock_openai.return_value = mock_client
yield mock_client

@@ -10,18 +10,19 @@ def _setup_mocks(mocker):
"""Helper to setup common mocks for both sync and async fixtures"""
mock_embedder = mocker.MagicMock()
mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
-mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder)
+mocker.patch("mem0.utils.factory.EmbedderFactory.create", mock_embedder)

mock_vector_store = mocker.MagicMock()
mock_vector_store.return_value.search.return_value = []
-mocker.patch('mem0.utils.factory.VectorStoreFactory.create',
-side_effect=[mock_vector_store.return_value, mocker.MagicMock()])
+mocker.patch(
+"mem0.utils.factory.VectorStoreFactory.create", side_effect=[mock_vector_store.return_value, mocker.MagicMock()]
+)

mock_llm = mocker.MagicMock()
-mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm)
+mocker.patch("mem0.utils.factory.LlmFactory.create", mock_llm)

-mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock())
+mocker.patch("mem0.memory.storage.SQLiteManager", mocker.MagicMock())

return mock_llm, mock_vector_store
@@ -30,29 +31,26 @@ class TestAddToVectorStoreErrors:
def mock_memory(self, mocker):
"""Fixture that returns a Memory instance with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)

memory = Memory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"

return memory

def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
"""Test empty response from LLM during fact extraction"""
# Setup
mock_memory.llm.generate_response.return_value = ""

# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
-messages=[{"role": "user", "content": "test"}],
-metadata={},
-filters={},
-infer=True
+messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)

# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed

@@ -62,20 +60,14 @@ class TestAddToVectorStoreErrors:
"""Test empty response from LLM during memory actions"""
# Setup
# First call returns valid JSON, second call returns empty string
-mock_memory.llm.generate_response.side_effect = [
-'{"facts": ["test fact"]}',
-""
-]
+mock_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]

# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
-messages=[{"role": "user", "content": "test"}],
-metadata={},
-filters={},
-infer=True
+messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)

# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed
@@ -88,48 +80,39 @@ class TestAsyncAddToVectorStoreErrors:
def mock_async_memory(self, mocker):
"""Fixture for AsyncMemory with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)

memory = AsyncMemory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"

return memory

@pytest.mark.asyncio
async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
-mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
+mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.return_value = ""

with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
-messages=[{"role": "user", "content": "test"}],
-metadata={},
-filters={},
-infer=True
+messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)

assert result == []
assert "Error in new_retrieved_facts" in caplog.text

@pytest.mark.asyncio
async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
-mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
-mock_async_memory.llm.generate_response.side_effect = [
-'{"facts": ["test fact"]}',
-""
-]
+mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
+mock_async_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]

with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
-messages=[{"role": "user", "content": "test"}],
-metadata={},
-filters={},
-infer=True
+messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)

assert result == []
assert "Invalid JSON response" in caplog.text
@@ -17,11 +17,13 @@ def mock_openai():

@pytest.fixture
def memory_instance():
-with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
-"mem0.utils.factory.VectorStoreFactory"
-) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
-"mem0.memory.telemetry.capture_event"
-), patch("mem0.memory.graph_memory.MemoryGraph"):
+with (
+patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
+patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
+patch("mem0.utils.factory.LlmFactory") as mock_llm,
+patch("mem0.memory.telemetry.capture_event"),
+patch("mem0.memory.graph_memory.MemoryGraph"),
+):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()

@@ -30,13 +32,16 @@ def memory_instance():
config.graph_store.config = {"some_config": "value"}
return Memory(config)


@pytest.fixture
def memory_custom_instance():
-with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
-"mem0.utils.factory.VectorStoreFactory"
-) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
-"mem0.memory.telemetry.capture_event"
-), patch("mem0.memory.graph_memory.MemoryGraph"):
+with (
+patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
+patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
+patch("mem0.utils.factory.LlmFactory") as mock_llm,
+patch("mem0.memory.telemetry.capture_event"),
+patch("mem0.memory.graph_memory.MemoryGraph"),
+):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()

@@ -44,7 +49,7 @@ def memory_custom_instance():
config = MemoryConfig(
version="v1.1",
custom_fact_extraction_prompt="custom prompt extracting memory",
-custom_update_memory_prompt="custom prompt determining memory update"
+custom_update_memory_prompt="custom prompt determining memory update",
)
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph):
assert result["message"] == "Memories deleted successfully!"



@pytest.mark.parametrize(
"version, enable_graph, expected_result",
[

@@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else:
memory_instance.graph.get_all.assert_not_called()


def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
memory_custom_instance.llm.generate_response = Mock()

with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
-with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
+with patch(
+"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
+) as mock_get_update_memory_messages:
memory_custom_instance.add(messages=messages, user_id="test_user")

## custom prompt
##
mock_parse_messages.assert_called_once_with(messages)

memory_custom_instance.llm.generate_response.assert_any_call(
messages=[
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},

@@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance):
],
response_format={"type": "json_object"},
)

## custom update memory prompt
##
-mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
+mock_get_update_memory_messages.assert_called_once_with(
+[], [], memory_custom_instance.config.custom_update_memory_prompt
+)

memory_custom_instance.llm.generate_response.assert_any_call(
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
response_format={"type": "json_object"},
)
@@ -97,4 +97,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm
|
|||||||
|
|
||||||
call_args = mock_litellm.completion.call_args[1]
|
call_args = mock_litellm.completion.call_args[1]
|
||||||
assert call_args["messages"][0]["role"] == "system"
|
assert call_args["messages"][0]["role"] == "system"
|
||||||
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
|
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
-with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
-patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \
-patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential:
+with (
+patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient,
+patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient,
+patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential,
+):
# Create mocked instances for search and index clients.
mock_search_client = MockSearchClient.return_value
mock_index_client = MockIndexClient.return_value

# Mock the client._client._config.user_agent_policy.add_user_agent
mock_search_client._client = MagicMock()
mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
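The fixture rewrite above only swaps backslash-continued context managers for the parenthesized form introduced in Python 3.10; behaviour is unchanged. A minimal runnable sketch of the same syntax against stdlib targets (the patched names below are illustrative, not the mem0 classes):

from unittest.mock import patch

# Python 3.10+ allows grouping several context managers in parentheses,
# with a trailing comma, instead of chaining them with backslashes.
with (
    patch("os.getcwd", return_value="/tmp") as mock_cwd,
    patch("os.listdir", return_value=[]) as mock_listdir,
):
    import os

    assert os.getcwd() == "/tmp"
    assert os.listdir(".") == []
    assert mock_cwd.called and mock_listdir.called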
@@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients):
api_key="test-api-key",
embedding_model_dims=3,
compression_type="binary",  # testing binary quantization option
-use_float16=True
+use_float16=True,
)
# Return instance and clients for verification.
return instance, mock_search_client, mock_index_client
@@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients):

# --- Tests for AzureAISearchConfig ---


def test_config_validation_valid():
"""Test valid configurations are accepted."""
# Test minimal configuration
-config = AzureAISearchConfig(
-service_name="test-service",
-api_key="test-api-key",
-embedding_model_dims=768
-)
+config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)
assert config.collection_name == "mem0"  # Default value
assert config.service_name == "test-service"
assert config.api_key == "test-api-key"
assert config.embedding_model_dims == 768
assert config.compression_type is None
assert config.use_float16 is False

# Test with all optional parameters
config = AzureAISearchConfig(
collection_name="custom-index",
@@ -92,7 +91,7 @@ def test_config_validation_valid():
api_key="test-api-key",
embedding_model_dims=1536,
compression_type="scalar",
-use_float16=True
+use_float16=True,
)
assert config.collection_name == "custom-index"
assert config.compression_type == "scalar"
@@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type="invalid-type"  # Not a valid option
+compression_type="invalid-type",  # Not a valid option
)
assert "Invalid compression_type" in str(exc_info.value)

@@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
-use_compression=True  # Deprecated parameter
+use_compression=True,  # Deprecated parameter
)
# Fix: Use a partial string match instead of exact match
assert "use_compression" in str(exc_info.value)
@@ -132,7 +131,7 @@ def test_config_validation_extra_fields():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
-unknown_parameter="value"  # Extra field
+unknown_parameter="value",  # Extra field
)
assert "Extra fields not allowed" in str(exc_info.value)
assert "unknown_parameter" in str(exc_info.value)
@@ -140,30 +139,28 @@ def test_config_validation_extra_fields():

# --- Tests for AzureAISearch initialization ---


def test_initialization(mock_clients):
"""Test AzureAISearch initialization with different parameters."""
mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients

# Test with minimal parameters
instance = AzureAISearch(
-service_name="test-service",
-collection_name="test-index",
-api_key="test-api-key",
-embedding_model_dims=768
+service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768
)

# Verify initialization parameters
assert instance.index_name == "test-index"
assert instance.collection_name == "test-index"
assert instance.embedding_model_dims == 768
assert instance.compression_type == "none"  # Default when None is passed
assert instance.use_float16 is False

# Verify client creation
mock_azure_key_credential.assert_called_with("test-api-key")
assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]

# Verify index creation was called
mock_index_client.create_or_update_index.assert_called_once()
@@ -171,75 +168,75 @@ def test_initialization(mock_clients):
def test_initialization_with_compression_types(mock_clients):
"""Test initialization with different compression types."""
mock_search_client, mock_index_client, _ = mock_clients

# Test with scalar compression
instance = AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type="scalar"
+compression_type="scalar",
)
assert instance.compression_type == "scalar"

# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify scalar compression was configured
-assert hasattr(index.vector_search, 'compressions')
+assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))

# Test with binary compression
instance = AzureAISearch(
service_name="test-service",
collection_name="binary-index",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type="binary"
+compression_type="binary",
)
assert instance.compression_type == "binary"

# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify binary compression was configured
-assert hasattr(index.vector_search, 'compressions')
+assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))

# Test with no compression
instance = AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type=None
+compression_type=None,
)
assert instance.compression_type == "none"

# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify no compression was configured
-assert hasattr(index.vector_search, 'compressions')
+assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) == 0


def test_initialization_with_float_precision(mock_clients):
"""Test initialization with different float precision settings."""
mock_search_client, mock_index_client, _ = mock_clients

# Test with half precision (float16)
instance = AzureAISearch(
service_name="test-service",
collection_name="float16-index",
api_key="test-api-key",
embedding_model_dims=768,
-use_float16=True
+use_float16=True,
)
assert instance.use_float16 is True

# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients):
vector_field = next((f for f in index.fields if f.name == "vector"), None)
assert vector_field is not None
assert "Edm.Half" in vector_field.type

# Test with full precision (float32)
instance = AzureAISearch(
service_name="test-service",
collection_name="float32-index",
api_key="test-api-key",
embedding_model_dims=768,
-use_float16=False
+use_float16=False,
)
assert instance.use_float16 is False

# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients):

# --- Tests for create_col method ---


def test_create_col(azure_ai_search_instance):
"""Test the create_col method creates an index with the correct configuration."""
instance, _, mock_index_client = azure_ai_search_instance

# create_col is called during initialization, so we check the call that was already made
mock_index_client.create_or_update_index.assert_called_once()

# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]

# Check basic properties
assert index.name == "test-index"
assert len(index.fields) == 6  # id, user_id, run_id, agent_id, vector, payload

# Check that required fields are present
field_names = [f.name for f in index.fields]
assert "id" in field_names
@@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance):
assert "user_id" in field_names
assert "run_id" in field_names
assert "agent_id" in field_names

# Check that id is the key field
id_field = next(f for f in index.fields if f.name == "id")
assert id_field.key is True

# Check vector search configuration
assert index.vector_search is not None
assert len(index.vector_search.profiles) == 1
assert index.vector_search.profiles[0].name == "my-vector-config"
assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"

# Check algorithms
assert len(index.vector_search.algorithms) == 1
assert index.vector_search.algorithms[0].name == "my-algorithms-config"
assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))

# With binary compression and float16, we should have compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
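The assertions above pin down the index layout the store builds: one HNSW algorithm configuration, one vector search profile pointing at it, and, with binary compression plus float16, exactly one quantization compression entry. As orientation only, a rough sketch of how such a configuration is usually assembled with the azure-search-documents models; the constructor keyword arguments are assumptions inferred from the attributes the test reads, not copied from mem0, and may differ across SDK versions:

from azure.search.documents.indexes.models import (
    BinaryQuantizationCompression,
    HnswAlgorithmConfiguration,
    VectorSearch,
    VectorSearchProfile,
)

# Assumed shape: the names match what the test asserts ("my-algorithms-config",
# "my-vector-config", "myCompression"); treat the keyword arguments as a sketch.
vector_search = VectorSearch(
    algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
    profiles=[
        VectorSearchProfile(
            name="my-vector-config",
            algorithm_configuration_name="my-algorithms-config",
            compression_name="myCompression",
        )
    ],
    compressions=[BinaryQuantizationCompression(compression_name="myCompression")],
)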
@@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance):
def test_create_col_scalar_compression(mock_clients):
"""Test creating a collection with scalar compression."""
mock_search_client, mock_index_client, _ = mock_clients

AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type="scalar"
+compression_type="scalar",
)

# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]

# Check compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))

# Check profile references compression
assert index.vector_search.profiles[0].compression_name == "myCompression"
@@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients):
def test_create_col_no_compression(mock_clients):
"""Test creating a collection with no compression."""
mock_search_client, mock_index_client, _ = mock_clients

AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
-compression_type=None
+compression_type=None,
)

# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]

# Check compression configuration - should be empty
assert len(index.vector_search.compressions) == 0

# Check profile doesn't reference compression
assert index.vector_search.profiles[0].compression_name is None


# --- Tests for insert method ---


def test_insert_single(azure_ai_search_instance):
"""Test inserting a single vector."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance):
ids = ["doc1"]

# Fix: Include status_code: 201 in mock response
-mock_search_client.upload_documents.return_value = [
-{"status": True, "id": "doc1", "status_code": 201}
-]
+mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}]

instance.insert(vectors, payloads, ids)
@@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance):
def test_insert_multiple(azure_ai_search_instance):
"""Test inserting multiple vectors in one call."""
instance, mock_search_client, _ = azure_ai_search_instance

# Create multiple vectors
num_docs = 3
-vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)]
+vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
ids = [f"doc{i}" for i in range(num_docs)]

# Configure mock to return success for all documents (fix: add status_code 201)
mock_search_client.upload_documents.return_value = [
{"status": True, "id": id_val, "status_code": 201} for id_val in ids
]

# Insert the documents
instance.insert(vectors, payloads, ids)

# Verify upload_documents was called with correct documents
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]

# Verify all documents were included
assert len(documents) == num_docs

# Check first document
assert documents[0]["id"] == "doc0"
assert documents[0]["vector"] == [0.0, 0.1, 0.2]
assert documents[0]["payload"] == json.dumps(payloads[0])
assert documents[0]["user_id"] == "user0"

# Check last document
assert documents[2]["id"] == "doc2"
assert documents[2]["vector"] == [0.2, 0.3, 0.4]
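For readers checking the expected values in this test, the reformatted comprehension expands to the following (a small worked example, not new test code):

num_docs = 3
vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
# vectors == [[0.0, 0.1, 0.2], [0.1, 0.2, 0.3], [0.2, 0.3, 0.4]]
# hence the assertions documents[0]["vector"] == [0.0, 0.1, 0.2]
# and documents[2]["vector"] == [0.2, 0.3, 0.4]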
@@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance

# Configure mock to return an error for one document
-mock_search_client.upload_documents.return_value = [
-{"status": False, "id": "doc1", "errorMessage": "Azure error"}
-]
+mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}]

vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
@@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance):
# Configure mock to return mixed success/failure for multiple documents
mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1"},  # This should not cause failure
-{"status": False, "id": "doc2", "errorMessage": "Azure error"}
+{"status": False, "id": "doc2", "errorMessage": "Azure error"},
]

vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance):
with pytest.raises(Exception) as exc_info:
instance.insert(vectors, payloads, ids)

-assert "Insert failed for document doc2" in str(exc_info.value) or \
-"Insert failed for document doc1" in str(exc_info.value)
+assert "Insert failed for document doc2" in str(exc_info.value) or "Insert failed for document doc1" in str(
+exc_info.value
+)


def test_insert_with_missing_payload_fields(azure_ai_search_instance):
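The wrapped or-assertion above is hard to scan after reformatting. An equivalent and arguably clearer alternative, shown here purely as an illustration and not as part of this change, is to hand pytest.raises a regular expression:

import pytest

def insert_that_fails():
    raise Exception("Insert failed for document doc2")

# match= runs re.search against str(excinfo.value), so one pattern can cover
# both acceptable error messages.
with pytest.raises(Exception, match=r"Insert failed for document doc[12]"):
    insert_that_fails()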
@@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance):
def test_insert_with_http_error(azure_ai_search_instance):
"""Test insert when Azure client throws an HTTP error."""
instance, mock_search_client, _ = azure_ai_search_instance

# Configure mock to raise an HttpResponseError
mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")

vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
ids = ["doc1"]

# Insert should propagate the HTTP error
with pytest.raises(HttpResponseError) as exc_info:
instance.insert(vectors, payloads, ids)

assert "Azure service error" in str(exc_info.value)


# --- Tests for search method ---


def test_search_basic(azure_ai_search_instance):
"""Test basic vector search without filters."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance):
# Search with a vector
query_text = "test query"  # Add a query string
query_vector = [0.1, 0.2, 0.3]
-results = instance.search(
-query_text, query_vector, limit=5
-) # Pass the query string
+results = instance.search(query_text, query_vector, limit=5)  # Pass the query string

# Verify search was called correctly
mock_search_client.search.assert_called_once()
@@ -7,9 +7,7 @@ import dotenv
try:
from elasticsearch import Elasticsearch
except ImportError:
-raise ImportError(
-"Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`"
-) from None
+raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None

from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData
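On the import guard above: re-raising with `from None` suppresses the implicit exception chaining, so the reader sees only the actionable install hint instead of a "During handling of the above exception, another exception occurred" traceback. A minimal sketch of the same pattern with a deliberately made-up module name:

try:
    import a_package_that_is_not_installed  # placeholder name, not a real dependency
except ImportError:
    # `from None` drops the chained ModuleNotFoundError from the traceback,
    # leaving only this message.
    raise ImportError("Install the optional dependency with `pip install <package>`") from None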
@@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase):
def setUpClass(cls):
# Load environment variables before any test
dotenv.load_dotenv()

# Save original environment variables
cls.original_env = {
-'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'),
-'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'),
-'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'),
-'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id')
+"ES_URL": os.getenv("ES_URL", "http://localhost:9200"),
+"ES_USERNAME": os.getenv("ES_USERNAME", "test_user"),
+"ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"),
+"ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"),
}

# Set test environment variables
-os.environ['ES_URL'] = 'http://localhost'
-os.environ['ES_USERNAME'] = 'test_user'
-os.environ['ES_PASSWORD'] = 'test_password'
+os.environ["ES_URL"] = "http://localhost"
+os.environ["ES_USERNAME"] = "test_user"
+os.environ["ES_PASSWORD"] = "test_password"

def setUp(self):
# Create a mock Elasticsearch client with proper attributes
self.client_mock = MagicMock(spec=Elasticsearch)
@@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase):
self.client_mock.indices.create = MagicMock()
self.client_mock.indices.delete = MagicMock()
self.client_mock.indices.get_alias = MagicMock()

# Start patches BEFORE creating ElasticsearchDB instance
-patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock)
+patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock)
self.mock_es = patcher.start()
self.addCleanup(patcher.stop)

# Initialize ElasticsearchDB with test config and auto_create_index=False
self.es_db = ElasticsearchDB(
-host=os.getenv('ES_URL'),
+host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
-user=os.getenv('ES_USERNAME'),
-password=os.getenv('ES_PASSWORD'),
+user=os.getenv("ES_USERNAME"),
+password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
-auto_create_index=False # Disable auto creation for tests
+auto_create_index=False,  # Disable auto creation for tests
)

# Reset mock counts after initialization
self.client_mock.reset_mock()
@@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase):
# Test when index doesn't exist
self.client_mock.indices.exists.return_value = False
self.es_db.create_index()

# Verify index creation was called with correct settings
self.client_mock.indices.create.assert_called_once()
create_args = self.client_mock.indices.create.call_args[1]

# Verify basic index settings
self.assertEqual(create_args["index"], "test_collection")
self.assertIn("mappings", create_args["body"])

# Verify field mappings
mappings = create_args["body"]["mappings"]["properties"]
self.assertEqual(mappings["text"]["type"], "text")
@@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(mappings["vector"]["index"], True)
self.assertEqual(mappings["vector"]["similarity"], "cosine")
self.assertEqual(mappings["metadata"]["type"], "object")

# Reset mocks for next test
self.client_mock.reset_mock()

# Test when index already exists
self.client_mock.indices.exists.return_value = True
self.es_db.create_index()

# Verify create was not called when index exists
self.client_mock.indices.create.assert_not_called()

def test_auto_create_index(self):
# Reset mock
self.client_mock.reset_mock()

# Test with auto_create_index=True
ElasticsearchDB(
-host=os.getenv('ES_URL'),
+host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
-user=os.getenv('ES_USERNAME'),
-password=os.getenv('ES_PASSWORD'),
+user=os.getenv("ES_USERNAME"),
+password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
-auto_create_index=True
+auto_create_index=True,
)

# Verify create_index was called during initialization
self.client_mock.indices.exists.assert_called_once()

# Reset mock
self.client_mock.reset_mock()

# Test with auto_create_index=False
ElasticsearchDB(
-host=os.getenv('ES_URL'),
+host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
-user=os.getenv('ES_USERNAME'),
-password=os.getenv('ES_PASSWORD'),
+user=os.getenv("ES_USERNAME"),
+password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
-auto_create_index=False
+auto_create_index=False,
)

# Verify create_index was not called during initialization
self.client_mock.indices.exists.assert_not_called()
@@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"]

# Mock bulk operation
-with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk:
+with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk:
mock_bulk.return_value = (2, [])  # Simulate successful bulk insert

# Perform insert
results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids)

# Verify bulk was called
mock_bulk.assert_called_once()

# Verify bulk actions format
actions = mock_bulk.call_args[0][1]
self.assertEqual(len(actions), 2)
@@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(actions[0]["_id"], "id1")
self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])

# Verify returned objects
self.assertEqual(len(results), 2)
self.assertIsInstance(results[0], OutputData)
@@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
-{
-"_id": "id1",
-"_score": 0.8,
-"_source": {
-"vector": [0.1] * 1536,
-"metadata": {"key1": "value1"}
-}
-}
+{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}
]
}
}
@@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase):
# Verify search parameters
self.assertEqual(search_args["index"], "test_collection")
body = search_args["body"]

# Verify KNN query structure
self.assertIn("knn", body)
self.assertEqual(body["knn"]["field"], "vector")
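The test above only checks that the store issues a "knn" query against the "vector" field. For orientation, a typical Elasticsearch 8.x kNN search body looks roughly like the sketch below; beyond "knn" and "field": "vector", none of these keys are asserted by the test, so treat them as assumptions about the usual API shape rather than mem0's exact request:

# Assumed shape of an Elasticsearch 8.x kNN search body; only "knn" and
# "field": "vector" are pinned down by the assertions above.
search_body = {
    "knn": {
        "field": "vector",
        "query_vector": [0.1] * 1536,
        "k": 5,
        "num_candidates": 50,
    },
    "_source": ["metadata", "text"],
}
# client.search(index="test_collection", body=search_body)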
@@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase):
self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)

# Verify custom search query was used
-self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
+self.client_mock.search.assert_called_once_with(
+index=self.es_db.collection_name, body={"custom_key": "custom_value"}
+)

def test_get(self):
# Mock get response with correct structure
mock_response = {
"_id": "id1",
-"_source": {
-"vector": [0.1] * 1536,
-"metadata": {"key": "value"},
-"text": "sample text"
-}
+"_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"},
}
self.client_mock.get.return_value = mock_response

# Perform get
result = self.es_db.get(vector_id="id1")

# Verify get call
-self.client_mock.get.assert_called_once_with(
-index="test_collection",
-id="id1"
-)
+self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")

# Verify result
self.assertIsNotNone(result)
self.assertEqual(result.id, "id1")
@@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase):
def test_get_not_found(self):
# Mock get raising exception
self.client_mock.get.side_effect = Exception("Not found")

# Verify get returns None when document not found
result = self.es_db.get(vector_id="nonexistent")
self.assertIsNone(result)
@@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
-{
-"_id": "id1",
-"_source": {
-"vector": [0.1] * 1536,
-"metadata": {"key1": "value1"}
-},
-"_score": 1.0
-},
-{
-"_id": "id2",
-"_source": {
-"vector": [0.2] * 1536,
-"metadata": {"key2": "value2"}
-},
-"_score": 0.8
-}
+{"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0},
+{"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8},
]
}
}
self.client_mock.search.return_value = mock_response

# Perform list operation
results = self.es_db.list(limit=10)

# Verify search call
self.client_mock.search.assert_called_once()

# Verify results
self.assertEqual(len(results), 1)  # Outer list
self.assertEqual(len(results[0]), 2)  # Inner list
@@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase):
def test_delete(self):
# Perform delete
self.es_db.delete(vector_id="id1")

# Verify delete call
-self.client_mock.delete.assert_called_once_with(
-index="test_collection",
-id="id1"
-)
+self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1")

def test_list_cols(self):
# Mock indices response
mock_indices = {"index1": {}, "index2": {}}
self.client_mock.indices.get_alias.return_value = mock_indices

# Get collections
result = self.es_db.list_cols()

# Verify result
self.assertEqual(result, ["index1", "index2"])

def test_delete_col(self):
# Delete collection
self.es_db.delete_col()

# Verify delete call
-self.client_mock.indices.delete.assert_called_once_with(
-index="test_collection"
-)
+self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
@@ -21,9 +21,9 @@ def mock_faiss_index():
def faiss_instance(mock_faiss_index):
with tempfile.TemporaryDirectory() as temp_dir:
# Mock the faiss index creation
-with patch('faiss.IndexFlatL2', return_value=mock_faiss_index):
+with patch("faiss.IndexFlatL2", return_value=mock_faiss_index):
# Mock the faiss.write_index function
-with patch('faiss.write_index'):
+with patch("faiss.write_index"):
# Create a FAISS instance with a temporary directory
faiss_store = FAISS(
collection_name="test_collection",
@@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index):

def test_create_col(faiss_instance, mock_faiss_index):
# Test creating a collection with euclidean distance
-with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
-with patch('faiss.write_index'):
+with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as mock_index_flat_l2:
+with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection")
mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)

# Test creating a collection with inner product distance
-with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
-with patch('faiss.write_index'):
+with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip:
+with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection", distance="inner_product")
mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
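The create_col test above switches between faiss.IndexFlatL2 (Euclidean distance) and faiss.IndexFlatIP (inner product) depending on the requested metric. A small standalone sketch of the two index types with the real FAISS API, independent of the mocks used in these tests:

import faiss
import numpy as np

dims = 3
xb = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)  # database vectors
xq = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)                   # query vector

l2_index = faiss.IndexFlatL2(dims)  # squared L2 distance, lower is closer
ip_index = faiss.IndexFlatIP(dims)  # inner product, higher is closer

for index in (l2_index, ip_index):
    index.add(xb)
    scores, ids = index.search(xq, 2)  # top-2 neighbours
    print(scores, ids)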
@@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index):
|
|||||||
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
payloads = [{"name": "vector1"}, {"name": "vector2"}]
|
payloads = [{"name": "vector1"}, {"name": "vector2"}]
|
||||||
ids = ["id1", "id2"]
|
ids = ["id1", "id2"]
|
||||||
|
|
||||||
# Mock the numpy array conversion
|
# Mock the numpy array conversion
|
||||||
with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
|
with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
|
||||||
# Mock index.add
|
# Mock index.add
|
||||||
mock_faiss_index.add.return_value = None
|
mock_faiss_index.add.return_value = None
|
||||||
|
|
||||||
# Call insert
|
# Call insert
|
||||||
faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
|
faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
|
||||||
|
|
||||||
# Verify numpy.array was called
|
# Verify numpy.array was called
|
||||||
mock_np_array.assert_called_once_with(vectors, dtype=np.float32)
|
mock_np_array.assert_called_once_with(vectors, dtype=np.float32)
|
||||||
|
|
||||||
# Verify index.add was called
|
# Verify index.add was called
|
||||||
mock_faiss_index.add.assert_called_once()
|
mock_faiss_index.add.assert_called_once()
|
||||||
|
|
||||||
# Verify docstore and index_to_id were updated
|
# Verify docstore and index_to_id were updated
|
||||||
assert faiss_instance.docstore["id1"] == {"name": "vector1"}
|
assert faiss_instance.docstore["id1"] == {"name": "vector1"}
|
||||||
assert faiss_instance.docstore["id2"] == {"name": "vector2"}
|
assert faiss_instance.docstore["id2"] == {"name": "vector2"}
|
||||||
@@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index):
|
|||||||
def test_search(faiss_instance, mock_faiss_index):
|
def test_search(faiss_instance, mock_faiss_index):
|
||||||
# Prepare test data
|
# Prepare test data
|
||||||
query_vector = [0.1, 0.2, 0.3]
|
query_vector = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
# Setup the docstore and index_to_id mapping
|
# Setup the docstore and index_to_id mapping
|
||||||
faiss_instance.docstore = {
|
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
|
||||||
"id1": {"name": "vector1"},
|
|
||||||
"id2": {"name": "vector2"}
|
|
||||||
}
|
|
||||||
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
|
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
|
||||||
|
|
||||||
# First, create the mock for the search return values
|
# First, create the mock for the search return values
|
||||||
search_scores = np.array([[0.9, 0.8]])
|
search_scores = np.array([[0.9, 0.8]])
|
||||||
search_indices = np.array([[0, 1]])
|
search_indices = np.array([[0, 1]])
|
||||||
mock_faiss_index.search.return_value = (search_scores, search_indices)
|
mock_faiss_index.search.return_value = (search_scores, search_indices)
|
||||||
|
|
||||||
# Then patch numpy.array only for the query vector conversion
|
# Then patch numpy.array only for the query vector conversion
|
||||||
with patch('numpy.array') as mock_np_array:
|
with patch("numpy.array") as mock_np_array:
|
||||||
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
|
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
|
||||||
|
|
||||||
# Then patch _parse_output to return the expected results
|
# Then patch _parse_output to return the expected results
|
||||||
expected_results = [
|
expected_results = [
|
||||||
OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
|
OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
|
||||||
OutputData(id="id2", score=0.8, payload={"name": "vector2"})
|
OutputData(id="id2", score=0.8, payload={"name": "vector2"}),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
|
with patch.object(faiss_instance, "_parse_output", return_value=expected_results):
|
||||||
# Call search
|
# Call search
|
||||||
results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)
|
results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)
|
||||||
|
|
||||||
# Verify numpy.array was called (but we don't check exact call arguments since it's complex)
|
# Verify numpy.array was called (but we don't check exact call arguments since it's complex)
|
||||||
assert mock_np_array.called
|
assert mock_np_array.called
|
||||||
|
|
||||||
# Verify index.search was called
|
# Verify index.search was called
|
||||||
mock_faiss_index.search.assert_called_once()
|
mock_faiss_index.search.assert_called_once()
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert results[0].id == "id1"
|
assert results[0].id == "id1"
|
||||||
@@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index):
def test_search_with_filters(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]

    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # First set up the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)

    # Patch numpy.array for query vector conversion
    with patch("numpy.array") as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)

        # Directly mock the _parse_output method to return our expected values
        # We're simulating that _parse_output filters to just the first result
        all_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}),
        ]

        # Replace the _apply_filters method to handle our test case
        with patch.object(faiss_instance, "_parse_output", return_value=all_results):
            with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"):
                # Call search with filters
                results = faiss_instance.search(
                    query="test query", vectors=query_vector, limit=2, filters={"category": "A"}
                )

                # Verify numpy.array was called
                assert mock_np_array.called

                # Verify index.search was called
                mock_faiss_index.search.assert_called_once()

                # Verify filtered results - since we've mocked everything,
                # we should get just the result we want
                assert len(results) == 1
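The `_apply_filters` stand-in above is just a payload predicate taking `(payload, filters)` and returning a bool. As an editorial aside (hypothetical helper, not mem0's actual `_apply_filters` implementation), a minimal self-contained sketch of that idea:

    from typing import Dict, Optional

    def matches_filters(payload: Dict, filters: Optional[Dict]) -> bool:
        """Return True when every filter key is present in the payload with an equal value."""
        if not filters:
            return True
        return all(payload.get(key) == value for key, value in filters.items())

    # Mirrors the behaviour mocked in the test: only the payload with category "A" passes.
    payloads = [
        {"name": "vector1", "category": "A"},
        {"name": "vector2", "category": "B"},
    ]
    assert [p["name"] for p in payloads if matches_filters(p, {"category": "A"})] == ["vector1"]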
@@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index):

def test_delete(faiss_instance):
    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Call delete
    faiss_instance.delete(vector_id="id1")

    # Verify the vector was removed from docstore and index_to_id
    assert "id1" not in faiss_instance.docstore
    assert 0 not in faiss_instance.index_to_id
@@ -194,23 +182,20 @@ def test_delete(faiss_instance):

def test_update(faiss_instance, mock_faiss_index):
    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Test updating payload only
    faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
    assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}

    # Test updating vector
    # This requires mocking the delete and insert methods
    with patch.object(faiss_instance, "delete") as mock_delete:
        with patch.object(faiss_instance, "insert") as mock_insert:
            new_vector = [0.7, 0.8, 0.9]
            faiss_instance.update(vector_id="id2", vector=new_vector)

            # Verify delete and insert were called
            # Match the actual call signature (positional arg instead of keyword)
            mock_delete.assert_called_once_with("id2")
@@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index):

def test_get(faiss_instance):
    # Setup the docstore
    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}

    # Test getting an existing vector
    result = faiss_instance.get(vector_id="id1")
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
    assert result.score is None

    # Test getting a non-existent vector
    result = faiss_instance.get(vector_id="id3")
    assert result is None
@@ -240,18 +222,18 @@ def test_list(faiss_instance):
    faiss_instance.docstore = {
        "id1": {"name": "vector1", "category": "A"},
        "id2": {"name": "vector2", "category": "B"},
        "id3": {"name": "vector3", "category": "A"},
    }

    # Test listing all vectors
    results = faiss_instance.list()
    # Fix the expected result - the list method returns a list of lists
    assert len(results[0]) == 3

    # Test listing with a limit
    results = faiss_instance.list(limit=2)
    assert len(results[0]) == 2

    # Test listing with filters
    results = faiss_instance.list(filters={"category": "A"})
    assert len(results[0]) == 2
@@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index):
    # Mock index attributes
    mock_faiss_index.ntotal = 5
    mock_faiss_index.d = 128

    # Get collection info
    info = faiss_instance.col_info()

    # Verify the returned info
    assert info["name"] == "test_collection"
    assert info["count"] == 5
@@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index):

def test_delete_col(faiss_instance):
    # Mock the os.remove function
    with patch("os.remove") as mock_remove:
        with patch("os.path.exists", return_value=True):
            # Call delete_col
            faiss_instance.delete_col()

            # Verify os.remove was called twice (for index and docstore files)
            assert mock_remove.call_count == 2

            # Verify the internal state was reset
            assert faiss_instance.index is None
            assert faiss_instance.docstore == {}
@@ -293,17 +275,17 @@ def test_delete_col(faiss_instance):
def test_normalize_L2(faiss_instance, mock_faiss_index):
    # Setup a FAISS instance with normalize_L2=True
    faiss_instance.normalize_L2 = True

    # Prepare test data
    vectors = [[0.1, 0.2, 0.3]]

    # Mock numpy array conversion
    with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)):
        # Mock faiss.normalize_L2
        with patch("faiss.normalize_L2") as mock_normalize:
            # Call insert
            faiss_instance.insert(vectors=vectors, ids=["id1"])

            # Verify faiss.normalize_L2 was called
            mock_normalize.assert_called_once()
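The hunks above all follow the same mechanical pattern: string literals move to double quotes, short multi-line calls and collections collapse onto one line when they fit within the line-length limit, and collections that stay multi-line gain a trailing comma. A small before/after illustration with made-up names (an editorial sketch, not lines taken from this diff):

    # Before the formatting pass: single quotes, no trailing comma.
    records_before = [
        ('id1', {'name': 'vector1'}),
        ('id2', {'name': 'vector2'})
    ]

    # After: double quotes and a trailing comma on the multi-line literal.
    records_after = [
        ("id1", {"name": "vector1"}),
        ("id2", {"name": "vector2"}),
    ]

    assert records_before == records_after  # the reformatting never changes values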
@@ -11,11 +11,13 @@ def mock_langchain_client():
    with patch("langchain_community.vectorstores.VectorStore") as mock_client:
        yield mock_client


@pytest.fixture
def langchain_instance(mock_langchain_client):
    mock_client = Mock(spec=VectorStore)
    return Langchain(client=mock_client, collection_name="test_collection")


def test_insert_vectors(langchain_instance):
    # Test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance):
    # Test with add_embeddings method
    langchain_instance.client.add_embeddings = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids)

    # Test with add_texts method
    delattr(langchain_instance.client, "add_embeddings")  # Remove attribute completely
    langchain_instance.client.add_texts = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids)

    # Test with empty payloads
    langchain_instance.client.add_texts.reset_mock()
    langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids)


def test_search_vectors(langchain_instance):
    # Mock search results
    mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")]
    langchain_instance.client.similarity_search_by_vector.return_value = mock_docs

    # Test search without filters
    vectors = [[0.1, 0.2, 0.3]]
    results = langchain_instance.search(query="", vectors=vectors, limit=2)

    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2)

    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
@@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance):
    # Test search with filters
    filters = {"name": "vector1"}
    langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
    langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters)


def test_get_vector(langchain_instance):
    # Mock get result
@@ -90,7 +72,7 @@ def test_get_vector(langchain_instance):
    # Test get existing vector
    result = langchain_instance.get("id1")
    langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])

    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
@@ -8,9 +8,7 @@ import pytest
try:
    from opensearchpy import AWSV4SignerAuth, OpenSearch
except ImportError:
    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None

from mem0.vector_stores.opensearch import OpenSearchDB
@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
    def setUpClass(cls):
        dotenv.load_dotenv()
        cls.original_env = {
            "OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
            "OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
            "OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
        }
        os.environ["OS_URL"] = "http://localhost"
        os.environ["OS_USERNAME"] = "test_user"
        os.environ["OS_PASSWORD"] = "test_password"

    def setUp(self):
        self.client_mock = MagicMock(spec=OpenSearch)
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.delete = MagicMock()
        self.client_mock.search = MagicMock()

        patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
        self.mock_os = patcher.start()
        self.addCleanup(patcher.stop)

        self.os_db = OpenSearchDB(
            host=os.getenv("OS_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
            user=os.getenv("OS_USERNAME"),
            password=os.getenv("OS_PASSWORD"),
            verify_certs=False,
            use_ssl=False,
        )
        self.client_mock.reset_mock()
@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = ["id1", "id2"]

        # Mock the index method
        self.client_mock.index = MagicMock()

        results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)

        # Verify index was called twice (once for each vector)
        self.assertEqual(self.client_mock.index.call_count, 2)

        # Check first call
        first_call = self.client_mock.index.call_args_list[0]
        self.assertEqual(first_call[1]["index"], "test_collection")
        self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
        self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
        self.assertEqual(first_call[1]["body"]["id"], ids[0])

        # Check second call
        second_call = self.client_mock.index.call_args_list[1]
        self.assertEqual(second_call[1]["index"], "test_collection")
        self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
        self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
        self.assertEqual(second_call[1]["body"]["id"], ids[1])

        # Check results
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].id, "id1")
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.search.return_value = {"hits": {"hits": []}}
        result = self.os_db.get("nonexistent")
        self.assertIsNone(result)

    def test_update(self):
        vector = [0.3] * 1536
        payload = {"key3": "value3"}
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
        self.assertEqual(result, ["test_collection"])

    def test_search(self):
        mock_response = {
            "hits": {
                "hits": [
                    {
                        "_id": "id1",
                        "_score": 0.8,
                        "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
                    }
                ]
            }
        }
        self.client_mock.search.return_value = mock_response
        vectors = [[0.1] * 1536]
        results = self.os_db.search(query="", vectors=vectors, limit=5)
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
        self.os_db.delete_col()
        self.client_mock.indices.delete.assert_called_once_with(index="test_collection")

    def test_init_with_http_auth(self):
        mock_credentials = MagicMock()
        mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")

        with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
            OpenSearchDB(
                host="localhost",
                port=9200,
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
                embedding_model_dims=1536,
                http_auth=mock_signer,
                verify_certs=True,
                use_ssl=True,
            )

            # Verify OpenSearch was initialized with correct params
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
                use_ssl=True,
                verify_certs=True,
                connection_class=unittest.mock.ANY,
                pool_maxsize=20,
            )
@@ -12,6 +12,7 @@ def mock_pinecone_client():
    client.list_indexes.return_value.names.return_value = []
    return client


@pytest.fixture
def pinecone_db(mock_pinecone_client):
    return PineconeDB(
@@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client):
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )


def test_create_col_existing_index(mock_pinecone_client):
    # Set up the mock before creating the PineconeDB object
    mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]

    pinecone_db = PineconeDB(
        collection_name="test_index",
        embedding_model_dims=128,
@@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client):
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )

    # Reset the mock to verify it wasn't called during the test
    mock_pinecone_client.create_index.reset_mock()

    pinecone_db.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_not_called()


def test_create_col_new_index(pinecone_db, mock_pinecone_client):
    mock_pinecone_client.list_indexes.return_value.names.return_value = []
    pinecone_db.create_col(128, "cosine")
    mock_pinecone_client.create_index.assert_called()


def test_insert_vectors(pinecone_db):
    vectors = [[0.1] * 128, [0.2] * 128]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
@@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db):
    pinecone_db.insert(vectors, payloads, ids)
    pinecone_db.index.upsert.assert_called()


def test_search_vectors(pinecone_db):
    pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
    results = pinecone_db.search("test query", [0.1] * 128, limit=1)
    assert len(results) == 1
    assert results[0].id == "id1"
    assert results[0].score == 0.9


def test_update_vector(pinecone_db):
    pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
    pinecone_db.index.upsert.assert_called()


def test_get_vector_found(pinecone_db):
    # Looking at the _parse_output method, it expects a Vector object
    # or a list of dictionaries, not a dictionary with an 'id' field

    # Create a mock Vector object
    from pinecone.data.dataclasses.vector import Vector

    mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})

    # Mock the fetch method to return the mock response object
    mock_response = MagicMock()
    mock_response.vectors = {"id1": mock_vector}
    pinecone_db.index.fetch.return_value = mock_response

    result = pinecone_db.get("id1")
    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}


def test_delete_vector(pinecone_db):
    pinecone_db.delete("id1")
    pinecone_db.index.delete.assert_called_with(ids=["id1"])


def test_get_vector_not_found(pinecone_db):
    pinecone_db.index.fetch.return_value.vectors = {}
    result = pinecone_db.get("id1")
    assert result is None


def test_list_cols(pinecone_db):
    pinecone_db.list_cols()
    pinecone_db.client.list_indexes.assert_called()


def test_delete_col(pinecone_db):
    pinecone_db.delete_col()
    pinecone_db.client.delete_index.assert_called_with("test_index")


def test_col_info(pinecone_db):
    pinecone_db.col_info()
    pinecone_db.client.describe_index.assert_called_with("test_index")
@@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection):
        index_method=IndexMethod.HNSW,
        index_measure=IndexMeasure.COSINE,
    )

    # Manually set the collection attribute since we're mocking the initialization
    instance.collection = mock_collection
    return instance
@@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection):
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
    supabase_instance.create_col(1536)

    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
    mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")


def test_insert_vectors(supabase_instance, mock_collection):
@@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection):
    supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
    mock_collection.upsert.assert_called_once_with(expected_records)


def test_search_vectors(supabase_instance, mock_collection):
    mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})]
    mock_collection.query.return_value = mock_results

    vectors = [[0.1, 0.2, 0.3]]
@@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection):
    results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)

    mock_collection.query.assert_called_once_with(
        data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True
    )

    assert len(results) == 2
@@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection):

def test_list_vectors(supabase_instance, mock_collection):
    mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
    mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]

    mock_collection.query.return_value = mock_query_results
    mock_collection.fetch.return_value = mock_fetch_results
@@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection):
        "name": "test_collection",
        "count": 100,
        "dimension": 1536,
        "index": {"method": "hnsw", "metric": "cosine_distance"},
    }
@@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance):
    # Test multiple filters
    multi_filter = {"category": "test", "type": "document"}
    assert supabase_instance._preprocess_filters(multi_filter) == {
        "$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}]
    }

    # Test None filters
@@ -29,9 +29,7 @@ def upstash_instance(mock_index):

@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
    return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True)


def test_insert_vectors(upstash_instance, mock_index):
@@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index):

def test_search_vectors(upstash_instance, mock_index):
    mock_result = [
        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None),
        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None),
    ]

    upstash_instance.client.query_many.return_value = [mock_result]
@@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance):

    upstash_instance.delete(vector_id=vector_id)

    upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns")


def test_update_vector(upstash_instance):
@@ -115,18 +107,12 @@ def test_update_vector(upstash_instance):

def test_get_vector(upstash_instance):
    mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.fetch.return_value = mock_result

    result = upstash_instance.get(vector_id="id1")

    upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)

    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
@@ -134,15 +120,9 @@ def test_get_vector(upstash_instance):

def test_list_vectors(upstash_instance):
    mock_result = [
        QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
        QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
        QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
    ]
    handler = MagicMock()
@@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i

def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
    mock_result = [
        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
    ]

    upstash_instance_with_embeddings.client.query.return_value = mock_result
@@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed
        ValueError,
        match="When embeddings are enabled, all payloads must contain a 'data' field",
    ):
        upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)


def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
@@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance):
    result = upstash_instance.get(vector_id="nonexistent")

    upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
    assert result is None


def test_search_vectors_empty_filters(upstash_instance):
    mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.query_many.return_value = [mock_result]

    vectors = [[0.1, 0.2, 0.3]]
@@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine

@pytest.fixture
def mock_vertex_ai():
    with (
        patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index,
        patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint,
        patch("google.cloud.aiplatform.init") as mock_init,
    ):
        mock_index_instance = Mock()
        mock_endpoint_instance = Mock()
        yield {
            "index": mock_index_instance,
            "endpoint": mock_endpoint_instance,
            "init": mock_init,
            "index_class": mock_index,
            "endpoint_class": mock_endpoint,
        }


@pytest.fixture
def config():
    return GoogleMatchingEngineConfig(
        project_id="test-project",
        project_number="123456789",
        region="us-central1",
        endpoint_id="test-endpoint",
        index_id="test-index",
        deployment_index_id="test-deployment",
        collection_name="test-collection",
        vector_search_api_endpoint="test.vertexai.goog",
    )


@pytest.fixture
def vector_store(config, mock_vertex_ai):
    mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"]
    mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"]
    return GoogleMatchingEngine(**config.model_dump())


def test_initialization(vector_store, mock_vertex_ai, config):
    """Test proper initialization of GoogleMatchingEngine"""
    mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region)

    expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}"
    mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path)


def test_insert_vectors(vector_store, mock_vertex_ai):
    """Test inserting vectors with payloads"""
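The reformatted `mock_vertex_ai` fixture above groups its patches in a parenthesized `with` statement instead of backslash-continued lines; that form is accepted directly from Python 3.10 onward. A minimal standalone sketch of the same construct (hypothetical patch targets, not the fixture above):

    import os
    from unittest.mock import patch

    # Several patches grouped in one parenthesized `with` block (Python 3.10+),
    # rather than chained with backslash line continuations.
    with (
        patch("os.path.exists", return_value=True),
        patch("os.remove") as mock_remove,
    ):
        assert os.path.exists("anything")  # answered by the mock, not the filesystem
        os.remove("anything")  # recorded by the mock; nothing is actually deleted
        mock_remove.assert_called_once_with("anything")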
@@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai):

    vector_store.insert(vectors=vectors, payloads=payloads, ids=ids)

    mock_vertex_ai["index"].upsert_datapoints.assert_called_once()
    call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1]
    assert len(call_args["datapoints"]) == 1
    datapoint_str = str(call_args["datapoints"][0])
    assert "test-id" in datapoint_str
    assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str


def test_search_vectors(vector_store, mock_vertex_ai):
    """Test searching vectors with filters"""
    vectors = [[0.1, 0.2, 0.3]]
@@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_restrict.allow_list = ["test_user"]
    mock_restrict.name = "user_id"
    mock_restrict.allow_tokens = ["test_user"]

    mock_datapoint.restricts = [mock_restrict]

    mock_neighbor = Mock()
@@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_neighbor.datapoint = mock_datapoint
    mock_neighbor.restricts = [mock_restrict]

    mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]]

    results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)

    mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with(
        deployed_index_id=vector_store.deployment_index_id,
        queries=[vectors],
        num_neighbors=1,
        filter=[Namespace("user_id", ["test_user"], [])],
        return_full_datapoint=True,
    )

    assert len(results) == 1
@@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    assert results[0].score == 0.1
    assert results[0].payload == {"user_id": "test_user"}


def test_delete(vector_store, mock_vertex_ai):
    """Test deleting vectors"""
    vector_id = "test-id"

    remove_mock = Mock()

    with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy:
        with patch.object(vector_store.index, "remove_datapoints", remove_mock):
            vector_store.delete(ids=[vector_id])

            delete_spy.assert_called_once_with(ids=[vector_id])
            remove_mock.assert_called_once_with(datapoint_ids=[vector_id])


def test_error_handling(vector_store, mock_vertex_ai):
    """Test error handling during operations"""
    mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")

    with pytest.raises(Exception) as exc_info:
        vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"])

    assert isinstance(exc_info.value, exceptions.InvalidArgument)
    assert "Invalid request" in str(exc_info.value)
@@ -76,15 +76,15 @@
# self.client_mock.batch = MagicMock()

# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()

# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
# "results": [{"id": "id1"}, {"id": "id2"}]
# }

# vectors = [[0.1] * 1536, [0.2] * 1536]
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]

# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)

# def test_get(self):
@@ -108,7 +108,7 @@
# result = self.weaviate_db.get(vector_id=valid_uuid)

# assert result.id == valid_uuid

# expected_payload = mock_response.properties.copy()
# expected_payload["id"] = valid_uuid
@@ -131,10 +131,10 @@
# "metadata": {"distance": 0.2}
# }
# ]

# mock_response = MagicMock()
# mock_response.objects = []

# for obj in mock_objects:
# mock_obj = MagicMock()
# mock_obj.uuid = obj["uuid"]
@@ -142,16 +142,16 @@
# mock_obj.metadata = MagicMock()
# mock_obj.metadata.distance = obj["metadata"]["distance"]
# mock_response.objects.append(mock_obj)

# mock_hybrid = MagicMock()
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
# mock_hybrid.return_value = mock_response

# vectors = [[0.1] * 1536]
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)

# mock_hybrid.assert_called_once()

# self.assertEqual(len(results), 1)
# self.assertEqual(results[0].id, "id1")
# self.assertEqual(results[0].score, 0.8)
@@ -163,28 +163,28 @@

# def test_list(self):
# mock_objects = []

# mock_obj1 = MagicMock()
# mock_obj1.uuid = "id1"
# mock_obj1.properties = {"key1": "value1"}
# mock_objects.append(mock_obj1)

# mock_obj2 = MagicMock()
# mock_obj2.uuid = "id2"
# mock_obj2.properties = {"key2": "value2"}
# mock_objects.append(mock_obj2)

# mock_response = MagicMock()
# mock_response.objects = mock_objects

# mock_fetch = MagicMock()
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
# mock_fetch.return_value = mock_response

# results = self.weaviate_db.list(limit=10)

# mock_fetch.assert_called_once()

# # Verify results
# self.assertEqual(len(results), 1)
# self.assertEqual(len(results[0]), 2)