Formatting (#2750)

Author: Dev Khant
Date: 2025-05-22 01:17:29 +05:30
Committed by: GitHub
Parent: dff91154a7
Commit: d85fcda037
71 changed files with 1391 additions and 1823 deletions


@@ -56,6 +56,7 @@ jobs:
 run: |
 make install_all
 pip install -e ".[test]"
+pip install pinecone pinecone-text
 if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
 - name: Run Formatting
 run: |


@@ -13,7 +13,7 @@
"import anthropic\n", "import anthropic\n",
"\n", "\n",
"# Set up environment variables\n", "# Set up environment variables\n",
"os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n", "os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = \"your_anthropic_api_key\"" "os.environ[\"ANTHROPIC_API_KEY\"] = \"your_anthropic_api_key\""
] ]
}, },
@@ -33,7 +33,7 @@
" \"model\": \"claude-3-5-sonnet-latest\",\n", " \"model\": \"claude-3-5-sonnet-latest\",\n",
" \"temperature\": 0.1,\n", " \"temperature\": 0.1,\n",
" \"max_tokens\": 2000,\n", " \"max_tokens\": 2000,\n",
" }\n", " },\n",
" }\n", " }\n",
" }\n", " }\n",
" self.client = anthropic.Client(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n", " self.client = anthropic.Client(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n",
@@ -50,11 +50,7 @@
" - Keep track of open issues and follow-ups\n", " - Keep track of open issues and follow-ups\n",
" \"\"\"\n", " \"\"\"\n",
"\n", "\n",
" def store_customer_interaction(self,\n", " def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):\n",
" user_id: str,\n",
" message: str,\n",
" response: str,\n",
" metadata: Dict = None):\n",
" \"\"\"Store customer interaction in memory.\"\"\"\n", " \"\"\"Store customer interaction in memory.\"\"\"\n",
" if metadata is None:\n", " if metadata is None:\n",
" metadata = {}\n", " metadata = {}\n",
@@ -63,24 +59,17 @@
" metadata[\"timestamp\"] = datetime.now().isoformat()\n", " metadata[\"timestamp\"] = datetime.now().isoformat()\n",
"\n", "\n",
" # Format conversation for storage\n", " # Format conversation for storage\n",
" conversation = [\n", " conversation = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response}]\n",
" {\"role\": \"user\", \"content\": message},\n",
" {\"role\": \"assistant\", \"content\": response}\n",
" ]\n",
"\n", "\n",
" # Store in Mem0\n", " # Store in Mem0\n",
" self.memory.add(\n", " self.memory.add(conversation, user_id=user_id, metadata=metadata)\n",
" conversation,\n",
" user_id=user_id,\n",
" metadata=metadata\n",
" )\n",
"\n", "\n",
" def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:\n", " def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:\n",
" \"\"\"Retrieve relevant past interactions.\"\"\"\n", " \"\"\"Retrieve relevant past interactions.\"\"\"\n",
" return self.memory.search(\n", " return self.memory.search(\n",
" query=query,\n", " query=query,\n",
" user_id=user_id,\n", " user_id=user_id,\n",
" limit=5 # Adjust based on needs\n", " limit=5, # Adjust based on needs\n",
" )\n", " )\n",
"\n", "\n",
" def handle_customer_query(self, user_id: str, query: str) -> str:\n", " def handle_customer_query(self, user_id: str, query: str) -> str:\n",
@@ -112,15 +101,12 @@
" model=\"claude-3-5-sonnet-latest\",\n", " model=\"claude-3-5-sonnet-latest\",\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" max_tokens=2000,\n", " max_tokens=2000,\n",
" temperature=0.1\n", " temperature=0.1,\n",
" )\n", " )\n",
"\n", "\n",
" # Store interaction\n", " # Store interaction\n",
" self.store_customer_interaction(\n", " self.store_customer_interaction(\n",
" user_id=user_id,\n", " user_id=user_id, message=query, response=response, metadata={\"type\": \"support_query\"}\n",
" message=query,\n",
" response=response,\n",
" metadata={\"type\": \"support_query\"}\n",
" )\n", " )\n",
"\n", "\n",
" return response.content[0].text" " return response.content[0].text"
@@ -203,12 +189,12 @@
" # Get user input\n", " # Get user input\n",
" query = input()\n", " query = input()\n",
" print(\"Customer:\", query)\n", " print(\"Customer:\", query)\n",
" \n", "\n",
" # Check if user wants to exit\n", " # Check if user wants to exit\n",
" if query.lower() == 'exit':\n", " if query.lower() == \"exit\":\n",
" print(\"Thank you for using our support service. Goodbye!\")\n", " print(\"Thank you for using our support service. Goodbye!\")\n",
" break\n", " break\n",
" \n", "\n",
" # Handle the query and print the response\n", " # Handle the query and print the response\n",
" response = chatbot.handle_customer_query(user_id, query)\n", " response = chatbot.handle_customer_query(user_id, query)\n",
" print(\"Support:\", response, \"\\n\\n\")" " print(\"Support:\", response, \"\\n\\n\")"


@@ -25,7 +25,8 @@
"source": [ "source": [
"# Set up ENV Vars\n", "# Set up ENV Vars\n",
"import os\n", "import os\n",
"os.environ['OPENAI_API_KEY'] = \"sk-xxx\"\n" "\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
] ]
}, },
{ {
@@ -133,11 +134,9 @@
"assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n", "assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n",
"\n", "\n",
"# LLM Configuration\n", "# LLM Configuration\n",
"CACHE_SEED = 42 # choose your poison\n", "CACHE_SEED = 42 # choose your poison\n",
"llm_config = {\n", "llm_config = {\n",
" \"config_list\": [\n", " \"config_list\": [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}],\n",
" {\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}\n",
" ],\n",
" \"cache_seed\": CACHE_SEED,\n", " \"cache_seed\": CACHE_SEED,\n",
" \"timeout\": 120,\n", " \"timeout\": 120,\n",
" \"temperature\": 0.0,\n", " \"temperature\": 0.0,\n",
@@ -348,7 +347,7 @@
"source": [ "source": [
"# Retrieve the memory\n", "# Retrieve the memory\n",
"relevant_memories = MEM0_MEMORY_CLIENT.search(user_query, user_id=USER_ID, limit=3)\n", "relevant_memories = MEM0_MEMORY_CLIENT.search(user_query, user_id=USER_ID, limit=3)\n",
"relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n", "relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n",
"print(\"Relevant memories:\")\n", "print(\"Relevant memories:\")\n",
"print(relevant_memories_text)\n", "print(relevant_memories_text)\n",
"\n", "\n",
@@ -389,8 +388,8 @@
"# - Enables more context-aware and personalized agent responses.\n", "# - Enables more context-aware and personalized agent responses.\n",
"# - Bridges the gap between human input and AI processing in complex workflows.\n", "# - Bridges the gap between human input and AI processing in complex workflows.\n",
"\n", "\n",
"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
"\n", "\n",
"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
" def __init__(self, *args, **kwargs):\n", " def __init__(self, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n", " super().__init__(*args, **kwargs)\n",
" self.memory = MEM0_MEMORY_CLIENT\n", " self.memory = MEM0_MEMORY_CLIENT\n",
@@ -399,15 +398,14 @@
" def initiate_chat(self, assistant, message):\n", " def initiate_chat(self, assistant, message):\n",
" # Retrieve memory for the agent\n", " # Retrieve memory for the agent\n",
" agent_memories = self.memory.search(message, agent_id=self.agent_id, limit=3)\n", " agent_memories = self.memory.search(message, agent_id=self.agent_id, limit=3)\n",
" agent_memories_txt = '\\n'.join(mem['memory'] for mem in agent_memories)\n", " agent_memories_txt = \"\\n\".join(mem[\"memory\"] for mem in agent_memories)\n",
" prompt = f\"{message}\\n Coding Preferences: \\n{str(agent_memories_txt)}\"\n", " prompt = f\"{message}\\n Coding Preferences: \\n{str(agent_memories_txt)}\"\n",
" response = super().initiate_chat(assistant, message=prompt)\n", " response = super().initiate_chat(assistant, message=prompt)\n",
" # Add new memory after processing the message\n", " # Add new memory after processing the message\n",
" response_dist = response.__dict__ if not isinstance(response, dict) else response\n", " response_dist = response.__dict__ if not isinstance(response, dict) else response\n",
" MEMORY_DATA = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response_dist}]\n", " MEMORY_DATA = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response_dist}]\n",
" self.memory.add(MEMORY_DATA, agent_id=self.agent_id)\n", " self.memory.add(MEMORY_DATA, agent_id=self.agent_id)\n",
" return response\n", " return response"
" "
] ]
}, },
{ {
@@ -560,12 +558,12 @@
"from cookbooks.helper.mem0_teachability import Mem0Teachability\n", "from cookbooks.helper.mem0_teachability import Mem0Teachability\n",
"\n", "\n",
"teachability = Mem0Teachability(\n", "teachability = Mem0Teachability(\n",
" verbosity=2, # for visibility of what's happening\n", " verbosity=2, # for visibility of what's happening\n",
" recall_threshold=0.5,\n", " recall_threshold=0.5,\n",
" reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n", " reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n",
" agent_id=AGENT_ID,\n", " agent_id=AGENT_ID,\n",
" memory_client = MEM0_MEMORY_CLIENT,\n", " memory_client=MEM0_MEMORY_CLIENT,\n",
" )\n", ")\n",
"teachability.add_to_agent(user_proxy)" "teachability.add_to_agent(user_proxy)"
] ]
}, },


@@ -14,46 +14,47 @@ def process_item(item_data):
 local_results = defaultdict(list)
 for item in v:
-gt_answer = str(item['answer'])
-pred_answer = str(item['response'])
-category = str(item['category'])
-question = str(item['question'])
+gt_answer = str(item["answer"])
+pred_answer = str(item["response"])
+category = str(item["category"])
+question = str(item["question"])
 # Skip category 5
-if category == '5':
+if category == "5":
 continue
 metrics = calculate_metrics(pred_answer, gt_answer)
 bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
 llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)
-local_results[k].append({
-"question": question,
-"answer": gt_answer,
-"response": pred_answer,
-"category": category,
-"bleu_score": bleu_scores["bleu1"],
-"f1_score": metrics["f1"],
-"llm_score": llm_score
-})
+local_results[k].append(
+{
+"question": question,
+"answer": gt_answer,
+"response": pred_answer,
+"category": category,
+"bleu_score": bleu_scores["bleu1"],
+"f1_score": metrics["f1"],
+"llm_score": llm_score,
+}
+)
 return local_results
 def main():
-parser = argparse.ArgumentParser(description='Evaluate RAG results')
-parser.add_argument('--input_file', type=str,
-default="results/rag_results_500_k1.json",
-help='Path to the input dataset file')
-parser.add_argument('--output_file', type=str,
-default="evaluation_metrics.json",
-help='Path to save the evaluation results')
-parser.add_argument('--max_workers', type=int, default=10,
-help='Maximum number of worker threads')
+parser = argparse.ArgumentParser(description="Evaluate RAG results")
+parser.add_argument(
+"--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file"
+)
+parser.add_argument(
+"--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results"
+)
+parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads")
 args = parser.parse_args()
-with open(args.input_file, 'r') as f:
+with open(args.input_file, "r") as f:
 data = json.load(f)
 results = defaultdict(list)
@@ -61,18 +62,16 @@ def main():
 # Use ThreadPoolExecutor with specified workers
 with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-futures = [executor.submit(process_item, item_data)
-for item_data in data.items()]
-for future in tqdm(concurrent.futures.as_completed(futures),
-total=len(futures)):
+futures = [executor.submit(process_item, item_data) for item_data in data.items()]
+for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
 local_results = future.result()
 with results_lock:
 for k, items in local_results.items():
 results[k].extend(items)
 # Save results to JSON file
-with open(args.output_file, 'w') as f:
+with open(args.output_file, "w") as f:
 json.dump(results, f, indent=4)
 print(f"Results saved to {args.output_file}")


@@ -3,7 +3,7 @@ import json
 import pandas as pd
 # Load the evaluation metrics data
-with open('evaluation_metrics.json', 'r') as f:
+with open("evaluation_metrics.json", "r") as f:
 data = json.load(f)
 # Flatten the data into a list of question items
@@ -15,28 +15,20 @@ for key in data:
 df = pd.DataFrame(all_items)
 # Convert category to numeric type
-df['category'] = pd.to_numeric(df['category'])
+df["category"] = pd.to_numeric(df["category"])
 # Calculate mean scores by category
-result = df.groupby('category').agg({
-'bleu_score': 'mean',
-'f1_score': 'mean',
-'llm_score': 'mean'
-}).round(4)
+result = df.groupby("category").agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)
 # Add count of questions per category
-result['count'] = df.groupby('category').size()
+result["count"] = df.groupby("category").size()
 # Print the results
 print("Mean Scores Per Category:")
 print(result)
 # Calculate overall means
-overall_means = df.agg({
-'bleu_score': 'mean',
-'f1_score': 'mean',
-'llm_score': 'mean'
-}).round(4)
+overall_means = df.agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)
 print("\nOverall Mean Scores:")
 print(overall_means)


@@ -33,35 +33,34 @@ Do NOT include both CORRECT and WRONG in your response, or it will break the eva
 Just return the label CORRECT or WRONG in a json format with the key as "label".
 """
 def evaluate_llm_judge(question, gold_answer, generated_answer):
 """Evaluate the generated answer against the gold answer using an LLM judge."""
 response = client.chat.completions.create(
 model="gpt-4o-mini",
-messages=[{
-"role": "user",
-"content": ACCURACY_PROMPT.format(
-question=question,
-gold_answer=gold_answer,
-generated_answer=generated_answer
-)
-}],
+messages=[
+{
+"role": "user",
+"content": ACCURACY_PROMPT.format(
+question=question, gold_answer=gold_answer, generated_answer=generated_answer
+),
+}
+],
 response_format={"type": "json_object"},
-temperature=0.0
+temperature=0.0,
 )
-label = json.loads(response.choices[0].message.content)['label']
+label = json.loads(response.choices[0].message.content)["label"]
 return 1 if label == "CORRECT" else 0
 def main():
 """Main function to evaluate RAG results using LLM judge."""
-parser = argparse.ArgumentParser(
-description='Evaluate RAG results using LLM judge'
-)
+parser = argparse.ArgumentParser(description="Evaluate RAG results using LLM judge")
 parser.add_argument(
-'--input_file',
+"--input_file",
 type=str,
 default="results/default_run_v4_k30_new_graph.json",
-help='Path to the input dataset file'
+help="Path to the input dataset file",
 )
 args = parser.parse_args()
@@ -78,10 +77,10 @@ def main():
 index = 0
 for k, v in data.items():
 for x in v:
-question = x['question']
-gold_answer = x['answer']
-generated_answer = x['response']
-category = x['category']
+question = x["question"]
+gold_answer = x["answer"]
+generated_answer = x["response"]
+category = x["category"]
 # Skip category 5
 if int(category) == 5:
@@ -92,13 +91,15 @@ def main():
 LLM_JUDGE[category].append(label)
 # Store the results
-RESULTS[index].append({
-"question": question,
-"gt_answer": gold_answer,
-"response": generated_answer,
-"category": category,
-"llm_label": label
-})
+RESULTS[index].append(
+{
+"question": question,
+"gt_answer": gold_answer,
+"response": generated_answer,
+"category": category,
+"llm_label": label,
+}
+)
 # Save intermediate results
 with open(output_path, "w") as f:
@@ -108,8 +109,7 @@ def main():
print("All categories accuracy:") print("All categories accuracy:")
for cat, results in LLM_JUDGE.items(): for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} " print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
f"({sum(results)}/{len(results)})")
print("------------------------------------------") print("------------------------------------------")
index += 1 index += 1


@@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py
 @article{xu2025mem,
 title={A-mem: Agentic memory for llm agents},
 author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
 and Zhang, Yongfeng},
 journal={arXiv preprint arXiv:2502.12110},
 year={2025}
@@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim
 # Download required NLTK data
 try:
-nltk.download('punkt', quiet=True)
-nltk.download('wordnet', quiet=True)
+nltk.download("punkt", quiet=True)
+nltk.download("wordnet", quiet=True)
 except Exception as e:
 print(f"Error downloading NLTK data: {e}")
 # Initialize SentenceTransformer model (this will be reused)
 try:
-sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
+sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 except Exception as e:
 print(f"Warning: Could not load SentenceTransformer model: {e}")
 sentence_model = None
 def simple_tokenize(text):
 """Simple tokenization function."""
 # Convert to string if not already
 text = str(text)
-return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
+return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split()
 def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate ROUGE scores for prediction against reference."""
-scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
 scores = scorer.score(reference, prediction)
 return {
-'rouge1_f': scores['rouge1'].fmeasure,
-'rouge2_f': scores['rouge2'].fmeasure,
-'rougeL_f': scores['rougeL'].fmeasure
+"rouge1_f": scores["rouge1"].fmeasure,
+"rouge2_f": scores["rouge2"].fmeasure,
+"rougeL_f": scores["rougeL"].fmeasure,
 }
 def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate BLEU scores with different n-gram settings."""
 pred_tokens = nltk.word_tokenize(prediction.lower())
 ref_tokens = [nltk.word_tokenize(reference.lower())]
 weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
 smooth = SmoothingFunction().method1
 scores = {}
 for n, weights in enumerate(weights_list, start=1):
 try:
@@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
 except Exception as e:
 print(f"Error calculating BLEU score: {e}")
 score = 0.0
-scores[f'bleu{n}'] = score
+scores[f"bleu{n}"] = score
 return scores
 def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
 """Calculate BERTScore for semantic similarity."""
 try:
-P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
-return {
-'bert_precision': P.item(),
-'bert_recall': R.item(),
-'bert_f1': F1.item()
-}
+P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
+return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
 except Exception as e:
 print(f"Error calculating BERTScore: {e}")
-return {
-'bert_precision': 0.0,
-'bert_recall': 0.0,
-'bert_f1': 0.0
-}
+return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}
 def calculate_meteor_score(prediction: str, reference: str) -> float:
 """Calculate METEOR score for the prediction."""
@@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float:
print(f"Error calculating METEOR score: {e}") print(f"Error calculating METEOR score: {e}")
return 0.0 return 0.0
def calculate_sentence_similarity(prediction: str, reference: str) -> float: def calculate_sentence_similarity(prediction: str, reference: str) -> float:
"""Calculate sentence embedding similarity using SentenceBERT.""" """Calculate sentence embedding similarity using SentenceBERT."""
if sentence_model is None: if sentence_model is None:
@@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
 # Encode sentences
 embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
 embedding2 = sentence_model.encode([reference], convert_to_tensor=True)
 # Calculate cosine similarity
 similarity = pytorch_cos_sim(embedding1, embedding2).item()
 return float(similarity)
@@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
print(f"Error calculating sentence similarity: {e}") print(f"Error calculating sentence similarity: {e}")
return 0.0 return 0.0
def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]: def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
"""Calculate comprehensive evaluation metrics for a prediction.""" """Calculate comprehensive evaluation metrics for a prediction."""
# Handle empty or None values # Handle empty or None values
@@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
"bleu4": 0.0, "bleu4": 0.0,
"bert_f1": 0.0, "bert_f1": 0.0,
"meteor": 0.0, "meteor": 0.0,
"sbert_similarity": 0.0 "sbert_similarity": 0.0,
} }
# Convert to strings if they're not already # Convert to strings if they're not already
prediction = str(prediction).strip() prediction = str(prediction).strip()
reference = str(reference).strip() reference = str(reference).strip()
# Calculate exact match # Calculate exact match
exact_match = int(prediction.lower() == reference.lower()) exact_match = int(prediction.lower() == reference.lower())
# Calculate token-based F1 score # Calculate token-based F1 score
pred_tokens = set(simple_tokenize(prediction)) pred_tokens = set(simple_tokenize(prediction))
ref_tokens = set(simple_tokenize(reference)) ref_tokens = set(simple_tokenize(reference))
common_tokens = pred_tokens & ref_tokens common_tokens = pred_tokens & ref_tokens
if not pred_tokens or not ref_tokens: if not pred_tokens or not ref_tokens:
f1 = 0.0 f1 = 0.0
else: else:
precision = len(common_tokens) / len(pred_tokens) precision = len(common_tokens) / len(pred_tokens)
recall = len(common_tokens) / len(ref_tokens) recall = len(common_tokens) / len(ref_tokens)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
# Calculate all scores # Calculate all scores
bleu_scores = calculate_bleu_scores(prediction, reference) bleu_scores = calculate_bleu_scores(prediction, reference)
# Combine all metrics # Combine all metrics
metrics = { metrics = {
"exact_match": exact_match, "exact_match": exact_match,
@@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
 return metrics
-def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
+def aggregate_metrics(
+all_metrics: List[Dict[str, float]], all_categories: List[int]
+) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
 """Calculate aggregate statistics for all metrics, split by category."""
 if not all_metrics:
 return {}
 # Initialize aggregates for overall and per-category metrics
 aggregates = defaultdict(list)
 category_aggregates = defaultdict(lambda: defaultdict(list))
 # Collect all values for each metric, both overall and per category
 for metrics, category in zip(all_metrics, all_categories):
 for metric_name, value in metrics.items():
 aggregates[metric_name].append(value)
 category_aggregates[category][metric_name].append(value)
 # Calculate statistics for overall metrics
-results = {
-"overall": {}
-}
+results = {"overall": {}}
 for metric_name, values in aggregates.items():
 results["overall"][metric_name] = {
-'mean': statistics.mean(values),
-'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-'median': statistics.median(values),
-'min': min(values),
-'max': max(values),
-'count': len(values)
+"mean": statistics.mean(values),
+"std": statistics.stdev(values) if len(values) > 1 else 0.0,
+"median": statistics.median(values),
+"min": min(values),
+"max": max(values),
+"count": len(values),
 }
 # Calculate statistics for each category
 for category in sorted(category_aggregates.keys()):
 results[f"category_{category}"] = {}
 for metric_name, values in category_aggregates[category].items():
 if values: # Only calculate if we have values for this category
 results[f"category_{category}"][metric_name] = {
-'mean': statistics.mean(values),
-'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-'median': statistics.median(values),
-'min': min(values),
-'max': max(values),
-'count': len(values)
+"mean": statistics.mean(values),
+"std": statistics.stdev(values) if len(values) > 1 else 0.0,
+"median": statistics.median(values),
+"min": min(values),
+"max": max(values),
+"count": len(values),
 }
 return results


@@ -144,4 +144,4 @@ ANSWER_PROMPT_ZEP = """
 Question: {{question}}
 Answer:
 """


@@ -21,23 +21,15 @@ class Experiment:
 def main():
-parser = argparse.ArgumentParser(description='Run memory experiments')
-parser.add_argument('--technique_type', choices=TECHNIQUES, default='mem0',
-help='Memory technique to use')
-parser.add_argument('--method', choices=METHODS, default='add',
-help='Method to use')
-parser.add_argument('--chunk_size', type=int, default=1000,
-help='Chunk size for processing')
-parser.add_argument('--output_folder', type=str, default='results/',
-help='Output path for results')
-parser.add_argument('--top_k', type=int, default=30,
-help='Number of top memories to retrieve')
-parser.add_argument('--filter_memories', action='store_true', default=False,
-help='Whether to filter memories')
-parser.add_argument('--is_graph', action='store_true', default=False,
-help='Whether to use graph-based search')
-parser.add_argument('--num_chunks', type=int, default=1,
-help='Number of chunks to process')
+parser = argparse.ArgumentParser(description="Run memory experiments")
+parser.add_argument("--technique_type", choices=TECHNIQUES, default="mem0", help="Memory technique to use")
+parser.add_argument("--method", choices=METHODS, default="add", help="Method to use")
+parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for processing")
+parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
+parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
+parser.add_argument("--filter_memories", action="store_true", default=False, help="Whether to filter memories")
+parser.add_argument("--is_graph", action="store_true", default=False, help="Whether to use graph-based search")
+parser.add_argument("--num_chunks", type=int, default=1, help="Number of chunks to process")
 args = parser.parse_args()
@@ -46,33 +38,18 @@ def main():
 if args.technique_type == "mem0":
 if args.method == "add":
-memory_manager = MemoryADD(
-data_path='dataset/locomo10.json',
-is_graph=args.is_graph
-)
+memory_manager = MemoryADD(data_path="dataset/locomo10.json", is_graph=args.is_graph)
 memory_manager.process_all_conversations()
 elif args.method == "search":
 output_file_path = os.path.join(
 args.output_folder,
-f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json"
+f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json",
 )
-memory_searcher = MemorySearch(
-output_file_path,
-args.top_k,
-args.filter_memories,
-args.is_graph
-)
-memory_searcher.process_data_file('dataset/locomo10.json')
+memory_searcher = MemorySearch(output_file_path, args.top_k, args.filter_memories, args.is_graph)
+memory_searcher.process_data_file("dataset/locomo10.json")
 elif args.technique_type == "rag":
-output_file_path = os.path.join(
-args.output_folder,
-f"rag_results_{args.chunk_size}_k{args.num_chunks}.json"
-)
-rag_manager = RAGManager(
-data_path="dataset/locomo10_rag.json",
-chunk_size=args.chunk_size,
-k=args.num_chunks
-)
+output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
+rag_manager = RAGManager(data_path="dataset/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
 rag_manager.process_all_conversations(output_file_path)
 elif args.technique_type == "langmem":
 output_file_path = os.path.join(args.output_folder, "langmem_results.json")
@@ -85,11 +62,7 @@ def main():
 elif args.method == "search":
 output_file_path = os.path.join(args.output_folder, "zep_search_results.json")
 zep_manager = ZepSearch()
-zep_manager.process_data_file(
-"dataset/locomo10.json",
-"1",
-output_file_path
-)
+zep_manager.process_data_file("dataset/locomo10.json", "1", output_file_path)
 elif args.technique_type == "openai":
 output_file_path = os.path.join(args.output_folder, "openai_results.json")
 openai_manager = OpenAIPredict()


@@ -28,14 +28,12 @@ def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_i
 speaker_1_user_id=speaker_1_user_id,
 speaker_1_memories=speaker_1_memories,
 speaker_2_user_id=speaker_2_user_id,
-speaker_2_memories=speaker_2_memories
+speaker_2_memories=speaker_2_memories,
 )
 t1 = time.time()
 response = client.chat.completions.create(
-model=os.getenv("MODEL"),
-messages=[{"role": "system", "content": prompt}],
-temperature=0.0
+model=os.getenv("MODEL"), messages=[{"role": "system", "content": prompt}], temperature=0.0
 )
 t2 = time.time()
 return response.choices[0].message.content, t2 - t1
@@ -59,7 +57,9 @@ def prompt(state):
 class LangMem:
-def __init__(self,):
+def __init__(
+self,
+):
 self.store = InMemoryStore(
 index={
 "dims": 1536,
@@ -80,18 +80,12 @@ class LangMem:
 )
 def add_memory(self, message, config):
-return self.agent.invoke(
-{"messages": [{"role": "user", "content": message}]},
-config=config
-)
+return self.agent.invoke({"messages": [{"role": "user", "content": message}]}, config=config)
 def search_memory(self, query, config):
 try:
 t1 = time.time()
-response = self.agent.invoke(
-{"messages": [{"role": "user", "content": query}]},
-config=config
-)
+response = self.agent.invoke({"messages": [{"role": "user", "content": query}]}, config=config)
 t2 = time.time()
 return response["messages"][-1].content, t2 - t1
 except Exception as e:
@@ -102,7 +96,7 @@ class LangMem:
 class LangMemManager:
 def __init__(self, dataset_path):
 self.dataset_path = dataset_path
-with open(self.dataset_path, 'r') as f:
+with open(self.dataset_path, "r") as f:
 self.data = json.load(f)
 def process_all_conversations(self, output_file_path):
@@ -123,7 +117,7 @@ class LangMemManager:
 # Identify speakers
 for conv in chat_history:
-speakers.add(conv['speaker'])
+speakers.add(conv["speaker"])
 if len(speakers) != 2:
 raise ValueError(f"Expected 2 speakers, got {len(speakers)}")
@@ -134,50 +128,52 @@ class LangMemManager:
 # Add memories for each message
 for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
 message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
-if conv['speaker'] == speaker1:
+if conv["speaker"] == speaker1:
 agent1.add_memory(message, config)
-elif conv['speaker'] == speaker2:
+elif conv["speaker"] == speaker2:
 agent2.add_memory(message, config)
 else:
 raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}")
 # Process questions
 for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
-category = q['category']
+category = q["category"]
 if int(category) == 5:
 continue
-answer = q['answer']
-question = q['question']
+answer = q["answer"]
+question = q["question"]
 response1, speaker1_memory_time = agent1.search_memory(question, config)
 response2, speaker2_memory_time = agent2.search_memory(question, config)
-generated_answer, response_time = get_answer(
-question, speaker1, response1, speaker2, response2
-)
-result[key].append({
-"question": question,
-"answer": answer,
-"response1": response1,
-"response2": response2,
-"category": category,
-"speaker1_memory_time": speaker1_memory_time,
-"speaker2_memory_time": speaker2_memory_time,
-"response_time": response_time,
-'response': generated_answer
-})
+generated_answer, response_time = get_answer(question, speaker1, response1, speaker2, response2)
+result[key].append(
+{
+"question": question,
+"answer": answer,
+"response1": response1,
+"response2": response2,
+"category": category,
+"speaker1_memory_time": speaker1_memory_time,
+"speaker2_memory_time": speaker2_memory_time,
+"response_time": response_time,
+"response": generated_answer,
+}
+)
 return result
 # Use multiprocessing to process conversations in parallel
 with mp.Pool(processes=10) as pool:
-results = list(tqdm(
-pool.imap(process_conversation, list(self.data.items())),
-total=len(self.data),
-desc="Processing conversations"
-))
+results = list(
+tqdm(
+pool.imap(process_conversation, list(self.data.items())),
+total=len(self.data),
+desc="Processing conversations",
+)
+)
 # Combine results from all workers
 for result in results:
@@ -185,5 +181,5 @@ class LangMemManager:
 OUTPUT[key].extend(items)
 # Save final results
-with open(output_file_path, 'w') as f:
+with open(output_file_path, "w") as f:
 json.dump(OUTPUT, f, indent=4)


@@ -13,7 +13,7 @@ load_dotenv()
 # Update custom instructions
-custom_instructions ="""
+custom_instructions = """
 Generate personal memories that follow these guidelines:
 1. Each memory should be self-contained with complete context, including:
@@ -47,7 +47,7 @@ class MemoryADD:
 self.mem0_client = MemoryClient(
 api_key=os.getenv("MEM0_API_KEY"),
 org_id=os.getenv("MEM0_ORGANIZATION_ID"),
-project_id=os.getenv("MEM0_PROJECT_ID")
+project_id=os.getenv("MEM0_PROJECT_ID"),
 )
 self.mem0_client.update_project(custom_instructions=custom_instructions)
@@ -59,15 +59,16 @@ class MemoryADD:
 self.load_data()
 def load_data(self):
-with open(self.data_path, 'r') as f:
+with open(self.data_path, "r") as f:
 self.data = json.load(f)
 return self.data
 def add_memory(self, user_id, message, metadata, retries=3):
 for attempt in range(retries):
 try:
-_ = self.mem0_client.add(message, user_id=user_id, version="v2",
-metadata=metadata, enable_graph=self.is_graph)
+_ = self.mem0_client.add(
+message, user_id=user_id, version="v2", metadata=metadata, enable_graph=self.is_graph
+)
 return
 except Exception as e:
 if attempt < retries - 1:
@@ -78,13 +79,13 @@ class MemoryADD:
 def add_memories_for_speaker(self, speaker, messages, timestamp, desc):
 for i in tqdm(range(0, len(messages), self.batch_size), desc=desc):
-batch_messages = messages[i:i+self.batch_size]
+batch_messages = messages[i : i + self.batch_size]
 self.add_memory(speaker, batch_messages, metadata={"timestamp": timestamp})
 def process_conversation(self, item, idx):
-conversation = item['conversation']
-speaker_a = conversation['speaker_a']
-speaker_b = conversation['speaker_b']
+conversation = item["conversation"]
+speaker_a = conversation["speaker_a"]
+speaker_b = conversation["speaker_b"]
 speaker_a_user_id = f"{speaker_a}_{idx}"
 speaker_b_user_id = f"{speaker_b}_{idx}"
@@ -94,7 +95,7 @@ class MemoryADD:
 self.mem0_client.delete_all(user_id=speaker_b_user_id)
 for key in conversation.keys():
-if key in ['speaker_a', 'speaker_b'] or "date" in key or "timestamp" in key:
+if key in ["speaker_a", "speaker_b"] or "date" in key or "timestamp" in key:
 continue
 date_time_key = key + "_date_time"
@@ -104,10 +105,10 @@ class MemoryADD:
 messages = []
 messages_reverse = []
 for chat in chats:
-if chat['speaker'] == speaker_a:
+if chat["speaker"] == speaker_a:
 messages.append({"role": "user", "content": f"{speaker_a}: {chat['text']}"})
 messages_reverse.append({"role": "assistant", "content": f"{speaker_a}: {chat['text']}"})
-elif chat['speaker'] == speaker_b:
+elif chat["speaker"] == speaker_b:
 messages.append({"role": "assistant", "content": f"{speaker_b}: {chat['text']}"})
 messages_reverse.append({"role": "user", "content": f"{speaker_b}: {chat['text']}"})
 else:
@@ -116,11 +117,11 @@ class MemoryADD:
 # add memories for the two users on different threads
 thread_a = threading.Thread(
 target=self.add_memories_for_speaker,
-args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A")
+args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A"),
 )
 thread_b = threading.Thread(
 target=self.add_memories_for_speaker,
-args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B")
+args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B"),
 )
 thread_a.start()
@@ -134,10 +135,7 @@ class MemoryADD:
 if not self.data:
 raise ValueError("No data loaded. Please set data_path and call load_data() first.")
 with ThreadPoolExecutor(max_workers=max_workers) as executor:
-futures = [
-executor.submit(self.process_conversation, item, idx)
-for idx, item in enumerate(self.data)
-]
+futures = [executor.submit(self.process_conversation, item, idx) for idx, item in enumerate(self.data)]
 for future in futures:
 future.result()


@@ -16,12 +16,11 @@ load_dotenv()
 class MemorySearch:
-def __init__(self, output_path='results.json', top_k=10, filter_memories=False, is_graph=False):
+def __init__(self, output_path="results.json", top_k=10, filter_memories=False, is_graph=False):
 self.mem0_client = MemoryClient(
 api_key=os.getenv("MEM0_API_KEY"),
 org_id=os.getenv("MEM0_ORGANIZATION_ID"),
-project_id=os.getenv("MEM0_PROJECT_ID")
+project_id=os.getenv("MEM0_PROJECT_ID"),
 )
 self.top_k = top_k
 self.openai_client = OpenAI()
@@ -42,11 +41,18 @@ class MemorySearch:
 try:
 if self.is_graph:
 print("Searching with graph")
-memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
-filter_memories=self.filter_memories, enable_graph=True, output_format='v1.1')
+memories = self.mem0_client.search(
+query,
+user_id=user_id,
+top_k=self.top_k,
+filter_memories=self.filter_memories,
+enable_graph=True,
+output_format="v1.1",
+)
 else:
-memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
-filter_memories=self.filter_memories)
+memories = self.mem0_client.search(
+query, user_id=user_id, top_k=self.top_k, filter_memories=self.filter_memories
+)
 break
 except Exception as e:
 print("Retrying...")
@@ -57,64 +63,86 @@ class MemorySearch:
 end_time = time.time()
 if not self.is_graph:
-semantic_memories = [{'memory': memory['memory'],
-'timestamp': memory['metadata']['timestamp'],
-'score': round(memory['score'], 2)}
-for memory in memories]
+semantic_memories = [
+{
+"memory": memory["memory"],
+"timestamp": memory["metadata"]["timestamp"],
+"score": round(memory["score"], 2),
+}
+for memory in memories
+]
 graph_memories = None
 else:
-semantic_memories = [{'memory': memory['memory'],
-'timestamp': memory['metadata']['timestamp'],
-'score': round(memory['score'], 2)} for memory in memories['results']]
-graph_memories = [{"source": relation['source'], "relationship": relation['relationship'], "target": relation['target']} for relation in memories['relations']]
+semantic_memories = [
+{
+"memory": memory["memory"],
+"timestamp": memory["metadata"]["timestamp"],
+"score": round(memory["score"], 2),
+}
+for memory in memories["results"]
+]
+graph_memories = [
+{"source": relation["source"], "relationship": relation["relationship"], "target": relation["target"]}
+for relation in memories["relations"]
+]
 return semantic_memories, graph_memories, end_time - start_time
 def answer_question(self, speaker_1_user_id, speaker_2_user_id, question, answer, category):
-speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(speaker_1_user_id, question)
-speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(speaker_2_user_id, question)
+speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(
+speaker_1_user_id, question
+)
+speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(
+speaker_2_user_id, question
+)
-search_1_memory = [f"{item['timestamp']}: {item['memory']}"
-for item in speaker_1_memories]
-search_2_memory = [f"{item['timestamp']}: {item['memory']}"
-for item in speaker_2_memories]
+search_1_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_1_memories]
+search_2_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_2_memories]
 template = Template(self.ANSWER_PROMPT)
 answer_prompt = template.render(
-speaker_1_user_id=speaker_1_user_id.split('_')[0],
-speaker_2_user_id=speaker_2_user_id.split('_')[0],
+speaker_1_user_id=speaker_1_user_id.split("_")[0],
+speaker_2_user_id=speaker_2_user_id.split("_")[0],
 speaker_1_memories=json.dumps(search_1_memory, indent=4),
 speaker_2_memories=json.dumps(search_2_memory, indent=4),
 speaker_1_graph_memories=json.dumps(speaker_1_graph_memories, indent=4),
 speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4),
-question=question
+question=question,
 )
 t1 = time.time()
 response = self.openai_client.chat.completions.create(
-model=os.getenv("MODEL"),
-messages=[
-{"role": "system", "content": answer_prompt}
-],
-temperature=0.0
+model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
 )
 t2 = time.time()
 response_time = t2 - t1
-return response.choices[0].message.content, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time
+return (
+response.choices[0].message.content,
+speaker_1_memories,
+speaker_2_memories,
+speaker_1_memory_time,
+speaker_2_memory_time,
+speaker_1_graph_memories,
+speaker_2_graph_memories,
+response_time,
+)
 def process_question(self, val, speaker_a_user_id, speaker_b_user_id):
-question = val.get('question', '')
-answer = val.get('answer', '')
-category = val.get('category', -1)
-evidence = val.get('evidence', [])
-adversarial_answer = val.get('adversarial_answer', '')
+question = val.get("question", "")
+answer = val.get("answer", "")
+category = val.get("category", -1)
+evidence = val.get("evidence", [])
+adversarial_answer = val.get("adversarial_answer", "")
-response, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time = self.answer_question(
-speaker_a_user_id,
-speaker_b_user_id,
-question,
-answer,
-category
-)
+(
+response,
+speaker_1_memories,
+speaker_2_memories,
+speaker_1_memory_time,
+speaker_2_memory_time,
+speaker_1_graph_memories,
+speaker_2_graph_memories,
+response_time,
+) = self.answer_question(speaker_a_user_id, speaker_b_user_id, question, answer, category)
 result = {
 "question": question,
@@ -125,67 +153,63 @@ class MemorySearch:
"adversarial_answer": adversarial_answer, "adversarial_answer": adversarial_answer,
"speaker_1_memories": speaker_1_memories, "speaker_1_memories": speaker_1_memories,
"speaker_2_memories": speaker_2_memories, "speaker_2_memories": speaker_2_memories,
'num_speaker_1_memories': len(speaker_1_memories), "num_speaker_1_memories": len(speaker_1_memories),
'num_speaker_2_memories': len(speaker_2_memories), "num_speaker_2_memories": len(speaker_2_memories),
'speaker_1_memory_time': speaker_1_memory_time, "speaker_1_memory_time": speaker_1_memory_time,
'speaker_2_memory_time': speaker_2_memory_time, "speaker_2_memory_time": speaker_2_memory_time,
"speaker_1_graph_memories": speaker_1_graph_memories, "speaker_1_graph_memories": speaker_1_graph_memories,
"speaker_2_graph_memories": speaker_2_graph_memories, "speaker_2_graph_memories": speaker_2_graph_memories,
"response_time": response_time "response_time": response_time,
} }
# Save results after each question is processed # Save results after each question is processed
with open(self.output_path, 'w') as f: with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
return result return result
def process_data_file(self, file_path): def process_data_file(self, file_path):
with open(file_path, 'r') as f: with open(file_path, "r") as f:
data = json.load(f) data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa'] qa = item["qa"]
conversation = item['conversation'] conversation = item["conversation"]
speaker_a = conversation['speaker_a'] speaker_a = conversation["speaker_a"]
speaker_b = conversation['speaker_b'] speaker_b = conversation["speaker_b"]
speaker_a_user_id = f"{speaker_a}_{idx}" speaker_a_user_id = f"{speaker_a}_{idx}"
speaker_b_user_id = f"{speaker_b}_{idx}" speaker_b_user_id = f"{speaker_b}_{idx}"
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False): for question_item in tqdm(
result = self.process_question( qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
question_item, ):
speaker_a_user_id, result = self.process_question(question_item, speaker_a_user_id, speaker_b_user_id)
speaker_b_user_id
)
self.results[idx].append(result) self.results[idx].append(result)
# Save results after each question is processed # Save results after each question is processed
with open(self.output_path, 'w') as f: with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
# Final save at the end # Final save at the end
with open(self.output_path, 'w') as f: with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1): def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1):
def process_single_question(val): def process_single_question(val):
result = self.process_question(val, speaker_a_user_id, speaker_b_user_id) result = self.process_question(val, speaker_a_user_id, speaker_b_user_id)
# Save results after each question is processed # Save results after each question is processed
with open(self.output_path, 'w') as f: with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
return result return result
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(tqdm( results = list(
executor.map(process_single_question, qa_list), tqdm(executor.map(process_single_question, qa_list), total=len(qa_list), desc="Answering Questions")
total=len(qa_list), )
desc="Answering Questions"
))
# Final save at the end # Final save at the end
with open(self.output_path, 'w') as f: with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
return results return results


@@ -59,23 +59,19 @@ class OpenAIPredict:
self.results = defaultdict(list) self.results = defaultdict(list)
def search_memory(self, idx): def search_memory(self, idx):
with open(f"memories/{idx}.txt", "r") as file:
with open(f'memories/{idx}.txt', 'r') as file:
memories = file.read() memories = file.read()
return memories, 0 return memories, 0
def process_question(self, val, idx): def process_question(self, val, idx):
question = val.get('question', '') question = val.get("question", "")
answer = val.get('answer', '') answer = val.get("answer", "")
category = val.get('category', -1) category = val.get("category", -1)
evidence = val.get('evidence', []) evidence = val.get("evidence", [])
adversarial_answer = val.get('adversarial_answer', '') adversarial_answer = val.get("adversarial_answer", "")
-        response, search_memory_time, response_time, context = self.answer_question(
-            idx,
-            question
-        )
+        response, search_memory_time, response_time, context = self.answer_question(idx, question)
result = { result = {
"question": question, "question": question,
@@ -86,7 +82,7 @@ class OpenAIPredict:
"adversarial_answer": adversarial_answer, "adversarial_answer": adversarial_answer,
"search_memory_time": search_memory_time, "search_memory_time": search_memory_time,
"response_time": response_time, "response_time": response_time,
"context": context "context": context,
} }
return result return result
@@ -95,43 +91,35 @@ class OpenAIPredict:
memories, search_memory_time = self.search_memory(idx) memories, search_memory_time = self.search_memory(idx)
template = Template(ANSWER_PROMPT) template = Template(ANSWER_PROMPT)
-        answer_prompt = template.render(
-            memories=memories,
-            question=question
-        )
+        answer_prompt = template.render(memories=memories, question=question)
t1 = time.time() t1 = time.time()
-        response = self.openai_client.chat.completions.create(
-            model=os.getenv("MODEL"),
-            messages=[
-                {"role": "system", "content": answer_prompt}
-            ],
-            temperature=0.0
-        )
+        response = self.openai_client.chat.completions.create(
+            model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
+        )
t2 = time.time() t2 = time.time()
response_time = t2 - t1 response_time = t2 - t1
return response.choices[0].message.content, search_memory_time, response_time, memories return response.choices[0].message.content, search_memory_time, response_time, memories
def process_data_file(self, file_path, output_file_path): def process_data_file(self, file_path, output_file_path):
with open(file_path, 'r') as f: with open(file_path, "r") as f:
data = json.load(f) data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa'] qa = item["qa"]
-            for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
-                result = self.process_question(
-                    question_item,
-                    idx
-                )
+            for question_item in tqdm(
+                qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
+            ):
+                result = self.process_question(question_item, idx)
self.results[idx].append(result) self.results[idx].append(result)
# Save results after each question is processed # Save results after each question is processed
with open(output_file_path, 'w') as f: with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
# Final save at the end # Final save at the end
with open(output_file_path, 'w') as f: with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
@@ -141,4 +129,3 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
openai_predict = OpenAIPredict() openai_predict = OpenAIPredict()
openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path) openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path)


@@ -33,10 +33,7 @@ class RAGManager:
def generate_response(self, question, context): def generate_response(self, question, context):
template = Template(PROMPT) template = Template(PROMPT)
-        prompt = template.render(
-            CONTEXT=context,
-            QUESTION=question
-        )
+        prompt = template.render(CONTEXT=context, QUESTION=question)
max_retries = 3 max_retries = 3
retries = 0 retries = 0
@@ -47,19 +44,21 @@ class RAGManager:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=[ messages=[
{"role": "system", {
"content": "You are a helpful assistant that can answer " "role": "system",
"questions based on the provided context." "content": "You are a helpful assistant that can answer "
"If the question involves timing, use the conversation date for reference." "questions based on the provided context."
"Provide the shortest possible answer." "If the question involves timing, use the conversation date for reference."
"Use words directly from the conversation when possible." "Provide the shortest possible answer."
"Avoid using subjects in your answer."}, "Use words directly from the conversation when possible."
{"role": "user", "content": prompt} "Avoid using subjects in your answer.",
},
{"role": "user", "content": prompt},
], ],
temperature=0 temperature=0,
) )
t2 = time.time() t2 = time.time()
return response.choices[0].message.content.strip(), t2-t1 return response.choices[0].message.content.strip(), t2 - t1
except Exception as e: except Exception as e:
retries += 1 retries += 1
if retries > max_retries: if retries > max_retries:
@@ -69,21 +68,16 @@ class RAGManager:
def clean_chat_history(self, chat_history): def clean_chat_history(self, chat_history):
cleaned_chat_history = "" cleaned_chat_history = ""
for c in chat_history: for c in chat_history:
cleaned_chat_history += (f"{c['timestamp']} | {c['speaker']}: " cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
f"{c['text']}\n")
return cleaned_chat_history return cleaned_chat_history
def calculate_embedding(self, document): def calculate_embedding(self, document):
-        response = self.client.embeddings.create(
-            model=os.getenv("EMBEDDING_MODEL"),
-            input=document
-        )
+        response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
return response.data[0].embedding return response.data[0].embedding
def calculate_similarity(self, embedding1, embedding2): def calculate_similarity(self, embedding1, embedding2):
-        return np.dot(embedding1, embedding2) / (
-            np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+        return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
def search(self, query, chunks, embeddings, k=1): def search(self, query, chunks, embeddings, k=1):
""" """
@@ -101,10 +95,7 @@ class RAGManager:
""" """
t1 = time.time() t1 = time.time()
query_embedding = self.calculate_embedding(query) query_embedding = self.calculate_embedding(query)
-        similarities = [
-            self.calculate_similarity(query_embedding, embedding)
-            for embedding in embeddings
-        ]
+        similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]
# Get indices of top-k most similar chunks # Get indices of top-k most similar chunks
if k == 1: if k == 1:
@@ -118,7 +109,7 @@ class RAGManager:
combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices]) combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])
t2 = time.time() t2 = time.time()
return combined_chunks, t2-t1 return combined_chunks, t2 - t1
def create_chunks(self, chat_history, chunk_size=500): def create_chunks(self, chat_history, chunk_size=500):
""" """
@@ -139,7 +130,7 @@ class RAGManager:
# Split into chunks based on token count # Split into chunks based on token count
for i in range(0, len(tokens), chunk_size): for i in range(0, len(tokens), chunk_size):
chunk_tokens = tokens[i:i+chunk_size] chunk_tokens = tokens[i : i + chunk_size]
chunk = encoding.decode(chunk_tokens) chunk = encoding.decode(chunk_tokens)
chunks.append(chunk) chunks.append(chunk)
@@ -159,13 +150,9 @@ class RAGManager:
chat_history = value["conversation"] chat_history = value["conversation"]
questions = value["question"] questions = value["question"]
-            chunks, embeddings = self.create_chunks(
-                chat_history, self.chunk_size
-            )
-            for item in tqdm(
-                questions, desc="Answering questions", leave=False
-            ):
+            chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)
+            for item in tqdm(questions, desc="Answering questions", leave=False):
question = item["question"] question = item["question"]
answer = item.get("answer", "") answer = item.get("answer", "")
category = item["category"] category = item["category"]
@@ -174,22 +161,20 @@ class RAGManager:
context = chunks[0] context = chunks[0]
search_time = 0 search_time = 0
else: else:
-                    context, search_time = self.search(
-                        question, chunks, embeddings, k=self.k
-                    )
-                response, response_time = self.generate_response(
-                    question, context
-                )
-                FINAL_RESULTS[key].append({
-                    "question": question,
-                    "answer": answer,
-                    "category": category,
-                    "context": context,
-                    "response": response,
-                    "search_time": search_time,
-                    "response_time": response_time,
-                })
+                    context, search_time = self.search(question, chunks, embeddings, k=self.k)
+                response, response_time = self.generate_response(question, context)
+                FINAL_RESULTS[key].append(
+                    {
+                        "question": question,
+                        "answer": answer,
+                        "category": category,
+                        "context": context,
+                        "response": response,
+                        "search_time": search_time,
+                        "response_time": response_time,
+                    }
+                )
with open(output_file_path, "w+") as f: with open(output_file_path, "w+") as f:
json.dump(FINAL_RESULTS, f, indent=4) json.dump(FINAL_RESULTS, f, indent=4)


@@ -1,12 +1,3 @@
-TECHNIQUES = [
-    "mem0",
-    "rag",
-    "langmem",
-    "zep",
-    "openai"
-]
-METHODS = [
-    "add",
-    "search"
-]
+TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
+METHODS = ["add", "search"]


@@ -19,12 +19,12 @@ class ZepAdd:
self.load_data() self.load_data()
def load_data(self): def load_data(self):
with open(self.data_path, 'r') as f: with open(self.data_path, "r") as f:
self.data = json.load(f) self.data = json.load(f)
return self.data return self.data
def process_conversation(self, run_id, item, idx): def process_conversation(self, run_id, item, idx):
conversation = item['conversation'] conversation = item["conversation"]
user_id = f"run_id_{run_id}_experiment_user_{idx}" user_id = f"run_id_{run_id}_experiment_user_{idx}"
session_id = f"run_id_{run_id}_experiment_session_{idx}" session_id = f"run_id_{run_id}_experiment_session_{idx}"
@@ -41,7 +41,7 @@ class ZepAdd:
print("Starting to add memories... for user", user_id) print("Starting to add memories... for user", user_id)
for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"): for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
if key in ['speaker_a', 'speaker_b'] or "date" in key: if key in ["speaker_a", "speaker_b"] or "date" in key:
continue continue
date_time_key = key + "_date_time" date_time_key = key + "_date_time"
@@ -51,11 +51,13 @@ class ZepAdd:
for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False): for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
self.zep_client.memory.add( self.zep_client.memory.add(
session_id=session_id, session_id=session_id,
-                    messages=[Message(
-                        role=chat['speaker'],
-                        role_type="user",
-                        content=f"{timestamp}: {chat['text']}",
-                    )]
+                    messages=[
+                        Message(
+                            role=chat["speaker"],
+                            role_type="user",
+                            content=f"{timestamp}: {chat['text']}",
+                        )
+                    ],
) )
def process_all_conversations(self, run_id): def process_all_conversations(self, run_id):
@@ -71,4 +73,4 @@ if __name__ == "__main__":
parser.add_argument("--run_id", type=str, required=True) parser.add_argument("--run_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
zep_add = ZepAdd(data_path="../../dataset/locomo10.json") zep_add = ZepAdd(data_path="../../dataset/locomo10.json")
zep_add.process_all_conversations(args.run_id) zep_add.process_all_conversations(args.run_id)


@@ -42,9 +42,9 @@ class ZepSearch:
return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}" return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"
def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str: def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str:
facts = [f' - {edge.fact} ({self.format_edge_date_range(edge)})' for edge in edges] facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges]
entities = [f' - {node.name}: {node.summary}' for node in nodes] entities = [f" - {node.name}: {node.summary}" for node in nodes]
return TEMPLATE.format(facts='\n'.join(facts), entities='\n'.join(entities)) return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))
def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1): def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
start_time = time.time() start_time = time.time()
@@ -52,8 +52,14 @@ class ZepSearch:
while retries < max_retries: while retries < max_retries:
try: try:
user_id = f"run_id_{run_id}_experiment_user_{idx}" user_id = f"run_id_{run_id}_experiment_user_{idx}"
-                edges_results = (self.zep_client.graph.search(user_id=user_id, reranker='cross_encoder', query=query, scope='edges', limit=20)).edges
-                node_results = (self.zep_client.graph.search(user_id=user_id, reranker='rrf', query=query, scope='nodes', limit=20)).nodes
+                edges_results = (
+                    self.zep_client.graph.search(
+                        user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
+                    )
+                ).edges
+                node_results = (
+                    self.zep_client.graph.search(user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20)
+                ).nodes
context = self.compose_search_context(edges_results, node_results) context = self.compose_search_context(edges_results, node_results)
break break
except Exception as e: except Exception as e:
@@ -68,17 +74,13 @@ class ZepSearch:
return context, end_time - start_time return context, end_time - start_time
def process_question(self, run_id, val, idx): def process_question(self, run_id, val, idx):
question = val.get('question', '') question = val.get("question", "")
answer = val.get('answer', '') answer = val.get("answer", "")
category = val.get('category', -1) category = val.get("category", -1)
evidence = val.get('evidence', []) evidence = val.get("evidence", [])
adversarial_answer = val.get('adversarial_answer', '') adversarial_answer = val.get("adversarial_answer", "")
-        response, search_memory_time, response_time, context = self.answer_question(
-            run_id,
-            idx,
-            question
-        )
+        response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)
result = { result = {
"question": question, "question": question,
@@ -89,7 +91,7 @@ class ZepSearch:
"adversarial_answer": adversarial_answer, "adversarial_answer": adversarial_answer,
"search_memory_time": search_memory_time, "search_memory_time": search_memory_time,
"response_time": response_time, "response_time": response_time,
"context": context "context": context,
} }
return result return result
@@ -98,44 +100,35 @@ class ZepSearch:
context, search_memory_time = self.search_memory(run_id, idx, question) context, search_memory_time = self.search_memory(run_id, idx, question)
template = Template(ANSWER_PROMPT_ZEP) template = Template(ANSWER_PROMPT_ZEP)
-        answer_prompt = template.render(
-            memories=context,
-            question=question
-        )
+        answer_prompt = template.render(memories=context, question=question)
t1 = time.time() t1 = time.time()
-        response = self.openai_client.chat.completions.create(
-            model=os.getenv("MODEL"),
-            messages=[
-                {"role": "system", "content": answer_prompt}
-            ],
-            temperature=0.0
-        )
+        response = self.openai_client.chat.completions.create(
+            model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
+        )
t2 = time.time() t2 = time.time()
response_time = t2 - t1 response_time = t2 - t1
return response.choices[0].message.content, search_memory_time, response_time, context return response.choices[0].message.content, search_memory_time, response_time, context
def process_data_file(self, file_path, run_id, output_file_path): def process_data_file(self, file_path, run_id, output_file_path):
with open(file_path, 'r') as f: with open(file_path, "r") as f:
data = json.load(f) data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa'] qa = item["qa"]
-            for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
-                result = self.process_question(
-                    run_id,
-                    question_item,
-                    idx
-                )
+            for question_item in tqdm(
+                qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
+            ):
+                result = self.process_question(run_id, question_item, idx)
self.results[idx].append(result) self.results[idx].append(result)
# Save results after each question is processed # Save results after each question is processed
with open(output_file_path, 'w') as f: with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)
# Final save at the end # Final save at the end
with open(output_file_path, 'w') as f: with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4) json.dump(self.results, f, indent=4)


@@ -56,9 +56,7 @@
"\n", "\n",
"import os\n", "import os\n",
"\n", "\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n", "os.environ[\"OPENAI_API_KEY\"] = \"\""
" \"\"\n",
")"
] ]
}, },
{ {
@@ -149,7 +147,7 @@
" \"role\": \"assistant\",\n", " \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n", " \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n", " },\n",
"]\n" "]"
] ]
}, },
{ {
@@ -166,9 +164,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Store inferred memories (default behavior)\n", "# Store inferred memories (default behavior)\n",
"result = m.add(\n", "result = m.add(messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"})"
" messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"}\n",
")"
] ]
}, },
{ {


@@ -20,19 +20,19 @@ agent = Agent(
name="Fitness Agent", name="Fitness Agent",
model=OpenAIChat(id="gpt-4o"), model=OpenAIChat(id="gpt-4o"),
description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.", description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.",
markdown=True markdown=True,
) )
# Store user preferences as memory # Store user preferences as memory
def store_user_preferences(conversation: list, user_id: str = USER_ID): def store_user_preferences(conversation: list, user_id: str = USER_ID):
"""Store user preferences from conversation history""" """Store user preferences from conversation history"""
memory_client.add(conversation, user_id=user_id, output_format='v1.1') memory_client.add(conversation, user_id=user_id, output_format="v1.1")
# Memory-aware assistant function # Memory-aware assistant function
def fitness_coach(user_input: str, user_id: str = USER_ID): def fitness_coach(user_input: str, user_id: str = USER_ID):
memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query
memory_context = "\n".join(f"- {m['memory']}" for m in memories) memory_context = "\n".join(f"- {m['memory']}" for m in memories)
prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations. prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations.
@@ -48,113 +48,66 @@ User query:
memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id) memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id)
return response.content return response.content
# -------------------------------------------------- # --------------------------------------------------
# Store user preferences and memories # Store user preferences and memories
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": "Hi, Im Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle." "content": "Hi, Im Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle.",
}, },
{ {
"role": "assistant", "role": "assistant",
"content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago." "content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago.",
}, },
{ {
"role": "user", "role": "user",
"content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday." "content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday.",
}, },
{ {
"role": "assistant", "role": "assistant",
"content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays." "content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays.",
}, },
{"role": "user", "content": "After push days, I usually eat high-protein and moderate-carb meals to recover."},
{"role": "assistant", "content": "Noted — high-protein, moderate-carb meals after push workouts."},
{"role": "user", "content": "For pull days, I take whey protein and eat a banana after training."},
{"role": "assistant", "content": "Logged — whey protein and banana post pull workouts."},
{"role": "user", "content": "On leg days, I make sure to have complex carbs like rice or oats."},
{"role": "assistant", "content": "Noted — complex carbs like rice and oats are part of your leg day meals."},
{ {
"role": "user", "role": "user",
"content": "After push days, I usually eat high-protein and moderate-carb meals to recover." "content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery.",
},
{
"role": "assistant",
"content": "Noted — high-protein, moderate-carb meals after push workouts."
}, },
{"role": "assistant", "content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."},
{ {
"role": "user", "role": "user",
"content": "For pull days, I take whey protein and eat a banana after training." "content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after.",
}, },
{ {
"role": "assistant", "role": "assistant",
"content": "Logged — whey protein and banana post pull workouts." "content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward.",
}, },
{"role": "user", "content": "I prefer light dinners post-workout like tofu, soup, and vegetables."},
{"role": "assistant", "content": "Got it — light dinners post-workout: tofu, soup, and veggies."},
{ {
"role": "user", "role": "user",
"content": "On leg days, I make sure to have complex carbs like rice or oats." "content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey.",
},
{
"role": "assistant",
"content": "Noted — complex carbs like rice and oats are part of your leg day meals."
}, },
{"role": "assistant", "content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."},
{ {
"role": "user", "role": "user",
"content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery." "content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days.",
}, },
{ {
"role": "assistant", "role": "assistant",
"content": "I'll remember turmeric milk and magnesium as part of your leg day recovery." "content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges.",
},
{
"role": "user",
"content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after."
},
{
"role": "assistant",
"content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward."
},
{
"role": "user",
"content": "I prefer light dinners post-workout like tofu, soup, and vegetables."
},
{
"role": "assistant",
"content": "Got it — light dinners post-workout: tofu, soup, and veggies."
},
{
"role": "user",
"content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey."
},
{
"role": "assistant",
"content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."
},
{
"role": "user",
"content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days."
},
{
"role": "assistant",
"content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges."
},
{
"role": "user",
"content": "I track sleep and notice poor performance when I sleep less than 6 hours."
},
{
"role": "assistant",
"content": "Logged — performance drops when you get under 6 hours of sleep."
},
{
"role": "user",
"content": "I take magnesium supplements to help with muscle recovery and sleep quality."
},
{
"role": "assistant",
"content": "Remembered — magnesium helps you with recovery and sleep."
},
{
"role": "user",
"content": "I avoid caffeine after 4 PM because it affects my sleep."
},
{
"role": "assistant",
"content": "Got it — you avoid caffeine post-4 PM to protect your sleep."
}, },
{"role": "user", "content": "I track sleep and notice poor performance when I sleep less than 6 hours."},
{"role": "assistant", "content": "Logged — performance drops when you get under 6 hours of sleep."},
{"role": "user", "content": "I take magnesium supplements to help with muscle recovery and sleep quality."},
{"role": "assistant", "content": "Remembered — magnesium helps you with recovery and sleep."},
{"role": "user", "content": "I avoid caffeine after 4 PM because it affects my sleep."},
{"role": "assistant", "content": "Got it — you avoid caffeine post-4 PM to protect your sleep."},
] ]
store_user_preferences(messages) store_user_preferences(messages)


@@ -1,9 +1,11 @@
import asyncio import asyncio
import warnings import warnings
from google.adk.agents import Agent from google.adk.agents import Agent
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types from google.genai import types
from mem0 import MemoryClient from mem0 import MemoryClient
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict:
print(f"Storing patient information: {information[:30]}...") print(f"Storing patient information: {information[:30]}...")
# Get user_id from session state or use default # Get user_id from session state or use default
user_id = getattr(save_patient_info, 'user_id', 'default_user') user_id = getattr(save_patient_info, "user_id", "default_user")
# Store in Mem0 # Store in Mem0
response = mem0_client.add( mem0_client.add(
[{"role": "user", "content": information}], [{"role": "user", "content": information}],
user_id=user_id, user_id=user_id,
run_id="healthcare_session", run_id="healthcare_session",
metadata={"type": "patient_information"} metadata={"type": "patient_information"},
) )
return {"status": "success", "message": "Information saved"} return {"status": "success", "message": "Information saved"}
@@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str:
print(f"Searching for patient information: {query}") print(f"Searching for patient information: {query}")
# Get user_id from session state or use default # Get user_id from session state or use default
user_id = getattr(retrieve_patient_info, 'user_id', 'default_user') user_id = getattr(retrieve_patient_info, "user_id", "default_user")
# Search Mem0 # Search Mem0
results = mem0_client.search( results = mem0_client.search(
@@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str:
user_id=user_id, user_id=user_id,
run_id="healthcare_session", run_id="healthcare_session",
limit=5, limit=5,
threshold=0.7 # Higher threshold for more relevant results threshold=0.7, # Higher threshold for more relevant results
) )
if not results: if not results:
@@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict:
"status": "success", "status": "success",
"appointment_id": appointment_id, "appointment_id": appointment_id,
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}", "confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
"message": "Please arrive 15 minutes early to complete paperwork." "message": "Please arrive 15 minutes early to complete paperwork.",
} }
@@ -89,7 +91,7 @@ IMPORTANT GUIDELINES:
- For serious symptoms, always recommend consulting a healthcare professional. - For serious symptoms, always recommend consulting a healthcare professional.
- Keep all patient information confidential. - Keep all patient information confidential.
""", """,
tools=[save_patient_info, retrieve_patient_info, schedule_appointment] tools=[save_patient_info, retrieve_patient_info, schedule_appointment],
) )
# Set Up Session and Runner # Set Up Session and Runner
@@ -101,18 +103,10 @@ USER_ID = "Alex"
SESSION_ID = "session_001" SESSION_ID = "session_001"
# Create a session # Create a session
-session = session_service.create_session(
-    app_name=APP_NAME,
-    user_id=USER_ID,
-    session_id=SESSION_ID
-)
+session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID)
# Create the runner # Create the runner
-runner = Runner(
-    agent=healthcare_agent,
-    app_name=APP_NAME,
-    session_service=session_service
-)
+runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service)
# Interact with the Healthcare Assistant # Interact with the Healthcare Assistant
@@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id):
print(f"\n>>> Patient: {query}") print(f"\n>>> Patient: {query}")
# Format the user's message # Format the user's message
-    content = types.Content(
-        role='user',
-        parts=[types.Part(text=query)]
-    )
+    content = types.Content(role="user", parts=[types.Part(text=query)])
# Set user_id for tools to access # Set user_id for tools to access
save_patient_info.user_id = user_id save_patient_info.user_id = user_id
retrieve_patient_info.user_id = user_id retrieve_patient_info.user_id = user_id
# Run the agent # Run the agent
-    async for event in runner.run_async(
-        user_id=user_id,
-        session_id=session_id,
-        new_message=content
-    ):
+    async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content):
if event.is_final_response(): if event.is_final_response():
if event.content and event.content.parts: if event.content and event.content.parts:
response = event.content.parts[0].text response = event.content.parts[0].text
@@ -152,7 +139,7 @@ async def run_conversation():
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.", "Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
runner=runner, runner=runner,
user_id=USER_ID, user_id=USER_ID,
session_id=SESSION_ID session_id=SESSION_ID,
) )
# Request for health information # Request for health information
@@ -160,7 +147,7 @@ async def run_conversation():
"Can you tell me more about what might be causing my headaches?", "Can you tell me more about what might be causing my headaches?",
runner=runner, runner=runner,
user_id=USER_ID, user_id=USER_ID,
session_id=SESSION_ID session_id=SESSION_ID,
) )
# Schedule an appointment # Schedule an appointment
@@ -168,15 +155,12 @@ async def run_conversation():
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?", "I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
runner=runner, runner=runner,
user_id=USER_ID, user_id=USER_ID,
session_id=SESSION_ID session_id=SESSION_ID,
) )
# Test memory - should remember patient name, symptoms, and allergy # Test memory - should remember patient name, symptoms, and allergy
-    await call_agent_async(
-        "What medications should I avoid for my headaches?",
-        runner=runner,
-        user_id=USER_ID,
-        session_id=SESSION_ID
-    )
+    await call_agent_async(
+        "What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID
+    )
@@ -191,37 +175,28 @@ async def interactive_mode():
session_id = f"session_{hash(patient_id) % 1000:03d}" session_id = f"session_{hash(patient_id) % 1000:03d}"
# Create session for this user # Create session for this user
-    session = session_service.create_session(
-        app_name=APP_NAME,
-        user_id=patient_id,
-        session_id=session_id
-    )
+    session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id)
print(f"\nStarting conversation with patient ID: {patient_id}") print(f"\nStarting conversation with patient ID: {patient_id}")
print("Type your message and press Enter.") print("Type your message and press Enter.")
while True: while True:
user_input = input("\n>>> Patient: ").strip() user_input = input("\n>>> Patient: ").strip()
if user_input.lower() in ['exit', 'quit', 'bye']: if user_input.lower() in ["exit", "quit", "bye"]:
print("Ending conversation. Thank you!") print("Ending conversation. Thank you!")
break break
-        await call_agent_async(
-            user_input,
-            runner=runner,
-            user_id=patient_id,
-            session_id=session_id
-        )
+        await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id)
# Main execution # Main execution
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory') parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory")
parser.add_argument('--demo', action='store_true', help='Run the demo conversation') parser.add_argument("--demo", action="store_true", help="Run the demo conversation")
parser.add_argument('--interactive', action='store_true', help='Run in interactive mode') parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation') parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation")
args = parser.parse_args() args = parser.parse_args()
if args.demo: if args.demo:
@@ -231,5 +206,3 @@ if __name__ == "__main__":
else: else:
# Default to demo mode if no arguments provided # Default to demo mode if no arguments provided
asyncio.run(run_conversation()) asyncio.run(run_conversation())


@@ -16,26 +16,21 @@ from mem0 import Memory
# Configure Mem0 with Grok 3 and Qdrant # Configure Mem0 with Grok 3 and Qdrant
config = { config = {
"vector_store": { "vector_store": {"provider": "qdrant", "config": {"embedding_model_dims": 384}},
"provider": "qdrant",
"config": {
"embedding_model_dims": 384
}
},
"llm": { "llm": {
"provider": "xai", "provider": "xai",
"config": { "config": {
"model": "grok-3-beta", "model": "grok-3-beta",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000, "max_tokens": 2000,
} },
}, },
"embedder": { "embedder": {
"provider": "huggingface", "provider": "huggingface",
"config": { "config": {
"model": "all-MiniLM-L6-v2" # open embedding model "model": "all-MiniLM-L6-v2" # open embedding model
} },
} },
} }
# Instantiate memory layer # Instantiate memory layer
@@ -57,20 +52,14 @@ def recommend_movie_with_memory(user_id: str, user_query: str):
prompt += f"\nPreviously, the user mentioned: {past_memories}" prompt += f"\nPreviously, the user mentioned: {past_memories}"
# Generate movie recommendation using Grok 3 # Generate movie recommendation using Grok 3
-    response = grok_client.chat.completions.create(
-        model="grok-3-beta",
-        messages=[
-            {"role": "user", "content": prompt}
-        ]
-    )
+    response = grok_client.chat.completions.create(model="grok-3-beta", messages=[{"role": "user", "content": prompt}])
recommendation = response.choices[0].message.content recommendation = response.choices[0].message.content
# Store conversation in memory # Store conversation in memory
-    memory.add(
-        [{"role": "user", "content": user_query},
-         {"role": "assistant", "content": recommendation}],
-        user_id=user_id,
-        metadata={"category": "movie"}
-    )
+    memory.add(
+        [{"role": "user", "content": user_query}, {"role": "assistant", "content": recommendation}],
+        user_id=user_id,
+        metadata={"category": "movie"},
+    )
return recommendation return recommendation
@@ -81,10 +70,11 @@ if __name__ == "__main__":
user_id = "arshi" user_id = "arshi"
recommend_movie_with_memory(user_id, "I'm looking for a movie to watch tonight. Any suggestions?") recommend_movie_with_memory(user_id, "I'm looking for a movie to watch tonight. Any suggestions?")
# OUTPUT: You have watched Intersteller last weekend and you don't like horror movies, maybe you can watch "Purple Hearts" today. # OUTPUT: You have watched Intersteller last weekend and you don't like horror movies, maybe you can watch "Purple Hearts" today.
recommend_movie_with_memory(user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians.") recommend_movie_with_memory(
user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians."
)
# OUTPUT: Got it — no sad endings! You might enjoy "The Proposal" or "Love, Rosie". Theyre both light-hearted romcoms with happy vibes. # OUTPUT: Got it — no sad endings! You might enjoy "The Proposal" or "Love, Rosie". Theyre both light-hearted romcoms with happy vibes.
recommend_movie_with_memory(user_id, "Any light-hearted movie I can watch after work today?") recommend_movie_with_memory(user_id, "Any light-hearted movie I can watch after work today?")
# OUTPUT: Since you liked Crazy Rich Asians and The Proposal, how about "The Intern" or "Isnt It Romantic"? Both are upbeat, funny, and perfect for relaxing. # OUTPUT: Since you liked Crazy Rich Asians and The Proposal, how about "The Intern" or "Isnt It Romantic"? Both are upbeat, funny, and perfect for relaxing.
recommend_movie_with_memory(user_id, "Ive already watched The Intern. Something new maybe?") recommend_movie_with_memory(user_id, "Ive already watched The Intern. Something new maybe?")
# OUTPUT: No problem! Try "Your Place or Mine" - romcoms that match your taste and are tear-free! # OUTPUT: No problem! Try "Your Place or Mine" - romcoms that match your taste and are tear-free!


@@ -23,8 +23,8 @@ agent = Agent(
name="Personal Agent", name="Personal Agent",
model=OpenAIChat(id="gpt-4o"), model=OpenAIChat(id="gpt-4o"),
description="You are a helpful personal agent that helps me with day to day activities." description="You are a helpful personal agent that helps me with day to day activities."
"You can process both text and images.", "You can process both text and images.",
markdown=True markdown=True,
) )
@@ -35,24 +35,16 @@ def chat_user(user_input: str = None, user_id: str = "user_123", image_path: str
base64_image = base64.b64encode(image_file.read()).decode("utf-8") base64_image = base64.b64encode(image_file.read()).decode("utf-8")
# First: the text message # First: the text message
-        text_msg = {
-            "role": "user",
-            "content": user_input
-        }
+        text_msg = {"role": "user", "content": user_input}
# Second: the image message # Second: the image message
-        image_msg = {
-            "role": "user",
-            "content": {
-                "type": "image_url",
-                "image_url": {
-                    "url": f"data:image/jpeg;base64,{base64_image}"
-                }
-            }
-        }
+        image_msg = {
+            "role": "user",
+            "content": {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
+        }
# Send both as separate message objects # Send both as separate message objects
client.add([text_msg, image_msg], user_id=user_id, output_format='v1.1') client.add([text_msg, image_msg], user_id=user_id, output_format="v1.1")
print("✅ Image uploaded and stored in memory.") print("✅ Image uploaded and stored in memory.")
if user_input: if user_input:
@@ -92,10 +84,13 @@ print(chat_user("When is my test?", user_id=user_id))
# OUTPUT: Your pilot's test is on your birthday, which is in five days. You're turning 25! # OUTPUT: Your pilot's test is on your birthday, which is in five days. You're turning 25!
# Good luck with your preparations, and remember to take some time to relax amidst the studying. # Good luck with your preparations, and remember to take some time to relax amidst the studying.
print(chat_user("This is the picture of what I brought with me in the trip to Bahamas", print(
image_path="travel_items.jpeg", # this will be added to Mem0 memory chat_user(
user_id=user_id)) "This is the picture of what I brought with me in the trip to Bahamas",
print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", image_path="travel_items.jpeg", # this will be added to Mem0 memory
user_id=user_id)) user_id=user_id,
)
)
print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", user_id=user_id))
# OUTPUT: Yes, you did bring your sunglasses on your trip to the Bahamas along with your laptop, face masks and other items.. # OUTPUT: Yes, you did bring your sunglasses on your trip to the Bahamas along with your laptop, face masks and other items..
# Since you can't find them now, perhaps check the pockets of jackets you wore or in your luggage compartments. # Since you can't find them now, perhaps check the pockets of jackets you wore or in your luggage compartments.


@@ -7,6 +7,7 @@ In order to run this file, you need to set up your Mem0 API at Mem0 platform and
export OPENAI_API_KEY="your_openai_api_key" export OPENAI_API_KEY="your_openai_api_key"
export MEM0_API_KEY="your_mem0_api_key" export MEM0_API_KEY="your_mem0_api_key"
""" """
import asyncio import asyncio
from agents import Agent, Runner from agents import Agent, Runner
@@ -23,25 +24,19 @@ study_agent = Agent(
- Identify topics the user has struggled with (e.g., "I'm confused", "this is hard") - Identify topics the user has struggled with (e.g., "I'm confused", "this is hard")
- Help with spaced repetition by suggesting topics to revisit based on last review time - Help with spaced repetition by suggesting topics to revisit based on last review time
- Personalize answers using stored memories - Personalize answers using stored memories
- Summarize PDFs or notes the user uploads""") - Summarize PDFs or notes the user uploads""",
)
# Upload and store PDF to Mem0 # Upload and store PDF to Mem0
def upload_pdf(pdf_url: str, user_id: str): def upload_pdf(pdf_url: str, user_id: str):
-    pdf_message = {
-        "role": "user",
-        "content": {
-            "type": "pdf_url",
-            "pdf_url": {"url": pdf_url}
-        }
-    }
+    pdf_message = {"role": "user", "content": {"type": "pdf_url", "pdf_url": {"url": pdf_url}}}
client.add([pdf_message], user_id=user_id) client.add([pdf_message], user_id=user_id)
print("✅ PDF uploaded and processed into memory.") print("✅ PDF uploaded and processed into memory.")
# Main interaction loop with your personal study buddy # Main interaction loop with your personal study buddy
async def study_buddy(user_id: str, topic: str, user_input: str): async def study_buddy(user_id: str, topic: str, user_input: str):
memories = client.search(f"{topic}", user_id=user_id) memories = client.search(f"{topic}", user_id=user_id)
memory_context = "n".join(f"- {m['memory']}" for m in memories) memory_context = "n".join(f"- {m['memory']}" for m in memories)
@@ -56,9 +51,11 @@ Now respond to the user's new question or comment:
result = await Runner.run(study_agent, prompt) result = await Runner.run(study_agent, prompt)
response = result.final_output response = result.final_output
-    client.add([
-        {"role": "user", "content": f'''Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}'''}
-    ], user_id=user_id, metadata={"topic": topic})
+    client.add(
+        [{"role": "user", "content": f"""Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}"""}],
+        user_id=user_id,
+        metadata={"topic": topic},
+    )
return response return response
@@ -78,7 +75,12 @@ async def main():
# Demonstrate spaced repetition prompting # Demonstrate spaced repetition prompting
topic = "Momentum Conservation" topic = "Momentum Conservation"
-    print(await study_buddy(user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"))
+    print(
+        await study_buddy(
+            user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"
+        )
+    )
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())


@@ -57,7 +57,7 @@ def initialize_memory():
}, },
{ {
"role": "user", "role": "user",
"content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know." "content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know.",
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -65,7 +65,7 @@ def initialize_memory():
}, },
{ {
"role": "user", "role": "user",
"content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive." "content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive.",
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -73,7 +73,7 @@ def initialize_memory():
}, },
{ {
"role": "user", "role": "user",
"content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time." "content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time.",
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -81,7 +81,7 @@ def initialize_memory():
}, },
{ {
"role": "user", "role": "user",
"content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants." "content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants.",
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -89,7 +89,7 @@ def initialize_memory():
}, },
{ {
"role": "user", "role": "user",
"content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months." "content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months.",
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -135,11 +135,11 @@ def record_audio(filename="input.wav", record_seconds=5):
stream.close() stream.close()
p.terminate() p.terminate()
with wave.open(filename, 'wb') as wf: with wave.open(filename, "wb") as wf:
wf.setnchannels(channels) wf.setnchannels(channels)
wf.setsampwidth(p.get_sample_size(fmt)) wf.setsampwidth(p.get_sample_size(fmt))
wf.setframerate(rate) wf.setframerate(rate)
wf.writeframes(b''.join(frames)) wf.writeframes(b"".join(frames))
# ------------------ STT USING WHISPER ------------------ # ------------------ STT USING WHISPER ------------------
@@ -147,10 +147,7 @@ def transcribe_whisper(audio_path):
print("🔎 Transcribing with Whisper...") print("🔎 Transcribing with Whisper...")
try: try:
with open(audio_path, "rb") as audio_file: with open(audio_path, "rb") as audio_file:
-            transcript = openai_client.audio.transcriptions.create(
-                model="whisper-1",
-                file=audio_file
-            )
+            transcript = openai_client.audio.transcriptions.create(model="whisper-1", file=audio_file)
print(f"🗣️ You said: {transcript.text}") print(f"🗣️ You said: {transcript.text}")
return transcript.text return transcript.text
except Exception as e: except Exception as e:
@@ -165,9 +162,7 @@ def get_agent_response(user_input):
try: try:
-        task = Task(
-            description=f"Respond to: {user_input}",
-            expected_output="A short and relevant reply.",
-            agent=voice_agent
-        )
+        task = Task(
+            description=f"Respond to: {user_input}", expected_output="A short and relevant reply.", agent=voice_agent
+        )
crew = Crew( crew = Crew(
agents=[voice_agent], agents=[voice_agent],
@@ -175,22 +170,19 @@ def get_agent_response(user_input):
process=Process.sequential, process=Process.sequential,
verbose=True, verbose=True,
memory=True, memory=True,
-            memory_config={
-                "provider": "mem0",
-                "config": {"user_id": USER_ID}
-            }
+            memory_config={"provider": "mem0", "config": {"user_id": USER_ID}},
) )
result = crew.kickoff() result = crew.kickoff()
# Extract the text response from the complex result object # Extract the text response from the complex result object
if hasattr(result, 'raw'): if hasattr(result, "raw"):
return result.raw return result.raw
elif isinstance(result, dict) and 'raw' in result: elif isinstance(result, dict) and "raw" in result:
return result['raw'] return result["raw"]
elif isinstance(result, dict) and 'tasks_output' in result: elif isinstance(result, dict) and "tasks_output" in result:
outputs = result['tasks_output'] outputs = result["tasks_output"]
if outputs and isinstance(outputs, list) and len(outputs) > 0: if outputs and isinstance(outputs, list) and len(outputs) > 0:
return outputs[0].get('raw', str(result)) return outputs[0].get("raw", str(result))
# Fallback to string representation if we can't extract the raw response # Fallback to string representation if we can't extract the raw response
return str(result) return str(result)
@@ -204,10 +196,7 @@ def get_agent_response(user_input):
def speak_response(text): def speak_response(text):
print(f"🤖 Agent: {text}") print(f"🤖 Agent: {text}")
-    audio = tts_client.text_to_speech.convert(
-        text=text,
-        voice_id="JBFqnCBsd6RMkjVDRZzb",
-        model_id="eleven_multilingual_v2",
-        output_format="mp3_44100_128"
-    )
+    audio = tts_client.text_to_speech.convert(
+        text=text, voice_id="JBFqnCBsd6RMkjVDRZzb", model_id="eleven_multilingual_v2", output_format="mp3_44100_128"
+    )
play(audio) play(audio)
@@ -220,7 +209,7 @@ def run_voice_agent():
record_audio(tmp_audio.name) record_audio(tmp_audio.name)
try: try:
user_text = transcribe_whisper(tmp_audio.name) user_text = transcribe_whisper(tmp_audio.name)
if user_text.lower() in ['exit', 'quit', 'stop']: if user_text.lower() in ["exit", "quit", "stop"]:
print("👋 Exiting.") print("👋 Exiting.")
break break
response = get_agent_response(user_text) response = get_agent_response(user_text)


@@ -95,10 +95,7 @@ class MemoryClient:
self.client = client self.client = client
# Ensure the client has the correct base_url and headers # Ensure the client has the correct base_url and headers
self.client.base_url = httpx.URL(self.host) self.client.base_url = httpx.URL(self.host)
-            self.client.headers.update({
-                "Authorization": f"Token {self.api_key}",
-                "Mem0-User-ID": self.user_id
-            })
+            self.client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else: else:
self.client = httpx.Client( self.client = httpx.Client(
base_url=self.host, base_url=self.host,
@@ -237,7 +234,9 @@ class MemoryClient:
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}) capture_client_event(
"client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -357,10 +356,7 @@ class MemoryClient:
else: else:
entities = self.users() entities = self.users()
# Filter entities based on provided IDs using list comprehension # Filter entities based on provided IDs using list comprehension
-        to_delete = [
-            {"type": entity["type"], "name": entity["name"]}
-            for entity in entities["results"]
-        ]
+        to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params() params = self._prepare_params()
@@ -373,7 +369,9 @@ class MemoryClient:
response.raise_for_status() response.raise_for_status()
-        capture_client_event(
-            "client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"}
-        )
+        capture_client_event(
+            "client.delete_users",
+            self,
+            {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"},
+        )
return { return {
"message": "Entity deleted successfully." "message": "Entity deleted successfully."
@@ -454,7 +452,9 @@ class MemoryClient:
""" """
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)}) response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
response.raise_for_status() response.raise_for_status()
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}) capture_client_event(
"client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -527,7 +527,11 @@ class MemoryClient:
) )
payload = self._prepare_params( payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria} {
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
}
) )
response = self.client.patch( response = self.client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -537,7 +541,12 @@ class MemoryClient:
capture_client_event( capture_client_event(
"client.update_project", "client.update_project",
self, self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "sync"}, {
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"sync_type": "sync",
},
) )
return response.json() return response.json()
@@ -750,10 +759,7 @@ class AsyncMemoryClient:
self.async_client = client self.async_client = client
# Ensure the client has the correct base_url and headers # Ensure the client has the correct base_url and headers
self.async_client.base_url = httpx.URL(self.host) self.async_client.base_url = httpx.URL(self.host)
-            self.async_client.headers.update({
-                "Authorization": f"Token {self.api_key}",
-                "Mem0-User-ID": self.user_id
-            })
+            self.async_client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else: else:
self.async_client = httpx.AsyncClient( self.async_client = httpx.AsyncClient(
base_url=self.host, base_url=self.host,
@@ -768,7 +774,11 @@ class AsyncMemoryClient:
"""Validate the API key by making a test request.""" """Validate the API key by making a test request."""
try: try:
params = self._prepare_params() params = self._prepare_params()
-            response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params)
+            response = requests.get(
+                f"{self.host}/v1/ping/",
+                headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
+                params=params,
+            )
data = response.json() data = response.json()
response.raise_for_status() response.raise_for_status()
@@ -973,10 +983,7 @@ class AsyncMemoryClient:
else:
    entities = await self.users()
# Filter entities based on provided IDs using list comprehension
-to_delete = [
-    {"type": entity["type"], "name": entity["name"]}
-    for entity in entities["results"]
-]
+to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params()
@@ -988,7 +995,11 @@ class AsyncMemoryClient:
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()
-capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"})
+capture_client_event(
+    "client.delete_users",
+    self,
+    {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"},
+)
return {
    "message": "Entity deleted successfully."
    if (user_id or agent_id or app_id or run_id)
@@ -1091,8 +1102,10 @@ class AsyncMemoryClient:
@api_error_handler
async def update_project(
-    self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None,
-    retrieval_criteria: Optional[List[Dict[str, Any]]] = None
+    self,
+    custom_instructions: Optional[str] = None,
+    custom_categories: Optional[List[str]] = None,
+    retrieval_criteria: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    if not (self.org_id and self.project_id):
        raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1103,7 +1116,11 @@ class AsyncMemoryClient:
)
payload = self._prepare_params(
-    {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
+    {
+        "custom_instructions": custom_instructions,
+        "custom_categories": custom_categories,
+        "retrieval_criteria": retrieval_criteria,
+    }
)
response = await self.async_client.patch( response = await self.async_client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -1113,7 +1130,12 @@ class AsyncMemoryClient:
capture_client_event(
    "client.update_project",
    self,
-    {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"},
+    {
+        "custom_instructions": custom_instructions,
+        "custom_categories": custom_categories,
+        "retrieval_criteria": retrieval_criteria,
+        "sync_type": "async",
+    },
)
return response.json()
@@ -1174,4 +1196,3 @@ class AsyncMemoryClient:
response.raise_for_status() response.raise_for_status()
capture_client_event("client.feedback", self, data, {"sync_type": "async"}) capture_client_event("client.feedback", self, data, {"sync_type": "async"})
return response.json() return response.json()
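For reference, a minimal usage sketch of the update_project call whose formatting changes above; the client construction, credentials, and argument values are illustrative placeholders, not part of this commit:

from mem0 import MemoryClient

# Hypothetical credentials and project identifiers, for illustration only.
client = MemoryClient(api_key="your_api_key", org_id="your_org_id", project_id="your_project_id")

# update_project accepts the three optional keyword arguments shown in the diff.
client.update_project(
    custom_instructions="Only store facts the user states explicitly.",
    custom_categories=["travel", "food_preferences"],
    retrieval_criteria=[{"name": "recency", "weight": 1.0}],  # illustrative criterion shape
)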

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Union, Type
+from typing import Any, Dict, Optional, Type, Union
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@@ -7,33 +7,17 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index") collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host") host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port") port: int = Field(9200, description="OpenSearch port")
-user: Optional[str] = Field(
-    None, description="Username for authentication"
-)
-password: Optional[str] = Field(
-    None, description="Password for authentication"
-)
-api_key: Optional[str] = Field(
-    None, description="API key for authentication (if applicable)"
-)
-embedding_model_dims: int = Field(
-    1536, description="Dimension of the embedding vector"
-)
-verify_certs: bool = Field(
-    False, description="Verify SSL certificates (default False for OpenSearch)"
-)
-use_ssl: bool = Field(
-    False, description="Use SSL for connection (default False for OpenSearch)"
-)
-http_auth: Optional[object] = Field(
-    None, description="HTTP authentication method / AWS SigV4"
-)
+user: Optional[str] = Field(None, description="Username for authentication")
+password: Optional[str] = Field(None, description="Password for authentication")
+api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
+embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
+verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
+use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
+http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
connection_class: Optional[Union[str, Type]] = Field(
    "RequestsHttpConnection", description="Connection class for OpenSearch"
)
-pool_maxsize: int = Field(
-    20, description="Maximum number of connections in the pool"
-)
+pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -41,7 +25,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided # Check if host is provided
if not values.get("host"): if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch") raise ValueError("Host must be provided for OpenSearch")
return values return values
@model_validator(mode="before") @model_validator(mode="before")
@@ -52,7 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields extra_fields = input_fields - allowed_fields
if extra_fields: if extra_fields:
raise ValueError( raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Allowed fields: {', '.join(allowed_fields)}"
) )
return values return values
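As a usage sketch of the consolidated OpenSearchConfig fields above; the provider key and all values are assumptions for illustration, not taken from this commit:

from mem0 import Memory

# Assumed provider name "opensearch"; host and credentials are placeholders.
config = {
    "vector_store": {
        "provider": "opensearch",
        "config": {
            "collection_name": "mem0",
            "host": "localhost",
            "port": 9200,
            "user": "admin",
            "password": "admin",
            "embedding_model_dims": 1536,
            "use_ssl": False,
            "verify_certs": False,
        },
    }
}

m = Memory.from_config(config)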

View File

@@ -23,12 +23,12 @@ class AWSBedrockEmbedding(EmbeddingBase):
super().__init__(config) super().__init__(config)
self.config.model = self.config.model or "amazon.titan-embed-text-v1" self.config.model = self.config.model or "amazon.titan-embed-text-v1"
# Get AWS config from environment variables or use defaults # Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "") aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "") aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2") aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config # Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"): if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id aws_access_key = self.config.aws_access_key_id
@@ -36,7 +36,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
aws_secret_key = self.config.aws_secret_access_key aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"): if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region aws_region = self.config.aws_region
self.client = boto3.client( self.client = boto3.client(
"bedrock-runtime", "bedrock-runtime",
region_name=aws_region, region_name=aws_region,

View File

@@ -11,6 +11,7 @@ logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING) logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING) logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
class HuggingFaceEmbedding(EmbeddingBase): class HuggingFaceEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None): def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config) super().__init__(config)

View File

@@ -22,7 +22,8 @@ class Neo4jConfig(BaseModel):
if not url or not username or not password: if not url or not username or not password:
raise ValueError("Please provide 'url', 'username' and 'password'.") raise ValueError("Please provide 'url', 'username' and 'password'.")
return values return values
class MemgraphConfig(BaseModel): class MemgraphConfig(BaseModel):
url: Optional[str] = Field(None, description="Host address for the graph database") url: Optional[str] = Field(None, description="Host address for the graph database")
username: Optional[str] = Field(None, description="Username for the graph database") username: Optional[str] = Field(None, description="Username for the graph database")

View File

@@ -20,18 +20,19 @@ def extract_provider(model: str) -> str:
return provider return provider
raise ValueError(f"Unknown provider in model: {model}") raise ValueError(f"Unknown provider in model: {model}")
class AWSBedrockLLM(LLMBase): class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config) super().__init__(config)
if not self.config.model: if not self.config.model:
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0" self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
# Get AWS config from environment variables or use defaults # Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "") aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "") aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2") aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config # Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"): if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id aws_access_key = self.config.aws_access_key_id
@@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase):
aws_secret_key = self.config.aws_secret_access_key aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"): if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region aws_region = self.config.aws_region
self.client = boto3.client( self.client = boto3.client(
"bedrock-runtime", "bedrock-runtime",
region_name=aws_region, region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None, aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None, aws_secret_access_key=aws_secret_key if aws_secret_key else None,
) )
self.model_kwargs = { self.model_kwargs = {
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens, "max_tokens_to_sample": self.config.max_tokens,
@@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase):
input_body = { input_body = {
"inputText": prompt, "inputText": prompt,
"textGenerationConfig": { "textGenerationConfig": {
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000, "maxTokenCount": self.model_kwargs["max_tokens_to_sample"]
or self.model_kwargs["max_tokens"]
or 5000,
"topP": self.model_kwargs["top_p"] or 0.9, "topP": self.model_kwargs["top_p"] or 0.9,
"temperature": self.model_kwargs["temperature"] or 0.1, "temperature": self.model_kwargs["temperature"] or 0.1,
}, },
@@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase):
body = json.dumps(input_body) body = json.dumps(input_body)
if provider == "anthropic" or provider == "deepseek": if provider == "anthropic" or provider == "deepseek":
input_body = { input_body = {
"messages": [ "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
}
],
"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000, "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"temperature": self.model_kwargs["temperature"] or 0.1, "temperature": self.model_kwargs["temperature"] or 0.1,
"top_p": self.model_kwargs["top_p"] or 0.9, "top_p": self.model_kwargs["top_p"] or 0.9,
"anthropic_version": "bedrock-2023-05-31", "anthropic_version": "bedrock-2023-05-31",
} }
body = json.dumps(input_body) body = json.dumps(input_body)
response = self.client.invoke_model( response = self.client.invoke_model(
body=body, body=body,
@@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase):
modelId=self.config.model, modelId=self.config.model,
accept="application/json", accept="application/json",
contentType="application/json", contentType="application/json",
) )
return self._parse_response(response, tools) return self._parse_response(response, tools)
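The reformatted request body for the Anthropic/DeepSeek providers keeps the same shape as before; a standalone sketch with placeholder values, assuming the defaults shown in the diff:

import json

prompt = "Summarize the user's last message."  # placeholder prompt
input_body = {
    "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
    "max_tokens": 5000,
    "temperature": 0.1,
    "top_p": 0.9,
    "anthropic_version": "bedrock-2023-05-31",
}
body = json.dumps(input_body)  # then passed to client.invoke_model(body=body, ...)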

View File

@@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
def _build_filters_and_metadata( def _build_filters_and_metadata(
*, # Enforce keyword-only arguments *, # Enforce keyword-only arguments
user_id: Optional[str] = None, user_id: Optional[str] = None,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
actor_id: Optional[str] = None, # For query-time filtering actor_id: Optional[str] = None, # For query-time filtering
input_metadata: Optional[Dict[str, Any]] = None, input_metadata: Optional[Dict[str, Any]] = None,
input_filters: Optional[Dict[str, Any]] = None, input_filters: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, Any], Dict[str, Any]]: ) -> tuple[Dict[str, Any], Dict[str, Any]]:
""" """
Constructs metadata for storage and filters for querying based on session and actor identifiers. Constructs metadata for storage and filters for querying based on session and actor identifiers.
This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts: This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
@@ -78,10 +78,10 @@ def _build_filters_and_metadata(
- effective_query_filters (Dict[str, Any]): Filters for querying memories, - effective_query_filters (Dict[str, Any]): Filters for querying memories,
scoped to the determined session and potentially a resolved actor. scoped to the determined session and potentially a resolved actor.
""" """
base_metadata_template = deepcopy(input_metadata) if input_metadata else {} base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
effective_query_filters = deepcopy(input_filters) if input_filters else {} effective_query_filters = deepcopy(input_filters) if input_filters else {}
# ---------- resolve session id (mandatory) ---------- # ---------- resolve session id (mandatory) ----------
session_key, session_val = None, None session_key, session_val = None, None
if user_id: if user_id:
@@ -90,20 +90,20 @@ def _build_filters_and_metadata(
session_key, session_val = "agent_id", agent_id session_key, session_val = "agent_id", agent_id
elif run_id: elif run_id:
session_key, session_val = "run_id", run_id session_key, session_val = "run_id", run_id
if session_key is None: if session_key is None:
raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.") raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
base_metadata_template[session_key] = session_val base_metadata_template[session_key] = session_val
effective_query_filters[session_key] = session_val effective_query_filters[session_key] = session_val
# ---------- optional actor filter ---------- # ---------- optional actor filter ----------
resolved_actor_id = actor_id or effective_query_filters.get("actor_id") resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
if resolved_actor_id: if resolved_actor_id:
effective_query_filters["actor_id"] = resolved_actor_id effective_query_filters["actor_id"] = resolved_actor_id
return base_metadata_template, effective_query_filters return base_metadata_template, effective_query_filters
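Based on the helper above, a small sketch of what it returns for a user-scoped call with an actor filter; ids and metadata values are placeholders, and the call assumes the module-level helper is used directly within mem0.memory.main:

metadata, filters = _build_filters_and_metadata(
    user_id="alice",
    actor_id="agent-1",
    input_metadata={"source": "chat"},
)
# metadata -> {"source": "chat", "user_id": "alice"}
# filters  -> {"user_id": "alice", "actor_id": "agent-1"}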
setup_config() setup_config()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -189,7 +189,7 @@ class Memory(MemoryBase):
): ):
""" """
Create a new memory. Create a new memory.
Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required. Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.
Args: Args:
@@ -208,7 +208,7 @@ class Memory(MemoryBase):
creating procedural memories (typically requires 'agent_id'). Otherwise, memories creating procedural memories (typically requires 'agent_id'). Otherwise, memories
are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories. are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
prompt (str, optional): Prompt to use for the memory creation. Defaults to None. prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
Returns: Returns:
dict: A dictionary containing the result of the memory addition operation, typically dict: A dictionary containing the result of the memory addition operation, typically
@@ -216,14 +216,14 @@ class Memory(MemoryBase):
and potentially "relations" if graph store is enabled. and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}` Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
""" """
processed_metadata, effective_filters = _build_filters_and_metadata( processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id, user_id=user_id,
agent_id=agent_id, agent_id=agent_id,
run_id=run_id, run_id=run_id,
input_metadata=metadata, input_metadata=metadata,
) )
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError( raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories." f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -231,10 +231,10 @@ class Memory(MemoryBase):
if isinstance(messages, str): if isinstance(messages, str):
messages = [{"role": "user", "content": messages}] messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict): elif isinstance(messages, dict):
messages = [messages] messages = [messages]
elif not isinstance(messages, list): elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]") raise ValueError("messages must be str, dict, or list[dict]")
@@ -255,7 +255,7 @@ class Memory(MemoryBase):
vector_store_result = future1.result() vector_store_result = future1.result()
graph_result = future2.result() graph_result = future2.result()
if self.api_version == "v1.0": if self.api_version == "v1.0":
warnings.warn( warnings.warn(
"The current add API output format is deprecated. " "The current add API output format is deprecated. "
@@ -277,21 +277,21 @@ class Memory(MemoryBase):
def _add_to_vector_store(self, messages, metadata, filters, infer): def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer: if not infer:
returned_memories = [] returned_memories = []
for message_dict in messages: for message_dict in messages:
-    if not isinstance(message_dict, dict) or \
-        message_dict.get("role") is None or \
-        message_dict.get("content") is None:
+    if (
+        not isinstance(message_dict, dict)
+        or message_dict.get("role") is None
+        or message_dict.get("content") is None
+    ):
logger.warning(f"Skipping invalid message format: {message_dict}") logger.warning(f"Skipping invalid message format: {message_dict}")
continue continue
if message_dict["role"] == "system": if message_dict["role"] == "system":
continue continue
per_msg_meta = deepcopy(metadata) per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"] per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name") actor_name = message_dict.get("name")
if actor_name: if actor_name:
per_msg_meta["actor_id"] = actor_name per_msg_meta["actor_id"] = actor_name
@@ -311,8 +311,8 @@ class Memory(MemoryBase):
) )
return returned_memories return returned_memories
parsed_messages = parse_messages(messages) parsed_messages = parse_messages(messages)
if self.config.custom_fact_extraction_prompt: if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}" user_prompt = f"Input:\n{parsed_messages}"
@@ -336,7 +336,7 @@ class Memory(MemoryBase):
retrieved_old_memory = [] retrieved_old_memory = []
new_message_embeddings = {} new_message_embeddings = {}
for new_mem in new_retrieved_facts: for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem, "add") messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search( existing_memories = self.vector_store.search(
@@ -347,7 +347,7 @@ class Memory(MemoryBase):
) )
for mem in existing_memories: for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]}) retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
unique_data = {} unique_data = {}
for item in retrieved_old_memory: for item in retrieved_old_memory:
unique_data[item["id"]] = item unique_data[item["id"]] = item
@@ -389,7 +389,7 @@ class Memory(MemoryBase):
if not action_text: if not action_text:
logging.info("Skipping memory entry because of empty `text` field.") logging.info("Skipping memory entry because of empty `text` field.")
continue continue
event_type = resp.get("event") event_type = resp.get("event")
if event_type == "ADD": if event_type == "ADD":
memory_id = self._create_memory( memory_id = self._create_memory(
@@ -405,16 +405,23 @@ class Memory(MemoryBase):
existing_embeddings=new_message_embeddings, existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata), metadata=deepcopy(metadata),
) )
-    returned_memories.append({
-        "id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
-        "event": event_type, "previous_memory": resp.get("old_memory"),
-    })
+    returned_memories.append(
+        {
+            "id": temp_uuid_mapping[resp.get("id")],
+            "memory": action_text,
+            "event": event_type,
+            "previous_memory": resp.get("old_memory"),
+        }
+    )
elif event_type == "DELETE":
    self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
-    returned_memories.append({
-        "id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
-        "event": event_type,
-    })
+    returned_memories.append(
+        {
+            "id": temp_uuid_mapping[resp.get("id")],
+            "memory": action_text,
+            "event": event_type,
+        }
+    )
elif event_type == "NONE": elif event_type == "NONE":
logging.info("NOOP for Memory.") logging.info("NOOP for Memory.")
except Exception as e: except Exception as e:
@@ -462,11 +469,8 @@ class Memory(MemoryBase):
"actor_id", "actor_id",
"role", "role",
] ]
-core_and_promoted_keys = {
-    "data", "hash", "created_at", "updated_at", "id",
-    *promoted_payload_keys
-}
+core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
result_item = MemoryItem( result_item = MemoryItem(
id=memory.id, id=memory.id,
@@ -479,18 +483,16 @@ class Memory(MemoryBase):
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in memory.payload: if key in memory.payload:
result_item[key] = memory.payload[key] result_item[key] = memory.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
result_item["metadata"] = additional_metadata result_item["metadata"] = additional_metadata
return result_item return result_item
def get_all( def get_all(
self, self,
*, *,
user_id: Optional[str] = None, user_id: Optional[str] = None,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
@@ -505,7 +507,7 @@ class Memory(MemoryBase):
agent_id (str, optional): agent id agent_id (str, optional): agent id
run_id (str, optional): run id run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search. filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example, These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`. `filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100. limit (int, optional): The maximum number of memories to return. Defaults to 100.
@@ -515,21 +517,16 @@ class Memory(MemoryBase):
it might return a direct list (see deprecation warning). it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
""" """
_, effective_filters = _build_filters_and_metadata(
-    user_id=user_id,
-    agent_id=agent_id,
-    run_id=run_id,
-    input_filters=filters
+    user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
    raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
capture_event(
-    "mem0.get_all",
-    self,
-    {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
+    "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -542,9 +539,9 @@ class Memory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories] [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
) )
all_memories_result = future_memories.result() all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph: if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result} return {"results": all_memories_result, "relations": graph_entities_result}
@@ -556,26 +553,27 @@ class Memory(MemoryBase):
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return all_memories_result return all_memories_result
else: else:
return {"results": all_memories_result} return {"results": all_memories_result}
def _get_all_from_vector_store(self, filters, limit): def _get_all_from_vector_store(self, filters, limit):
memories_result = self.vector_store.list(filters=filters, limit=limit) memories_result = self.vector_store.list(filters=filters, limit=limit)
-actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
+actual_memories = (
+    memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
+)
promoted_payload_keys = [
-    "user_id", "agent_id", "run_id",
+    "user_id",
+    "agent_id",
+    "run_id",
    "actor_id",
    "role",
]
-core_and_promoted_keys = {
-    "data", "hash", "created_at", "updated_at", "id",
-    *promoted_payload_keys
-}
+core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = [] formatted_memories = []
for mem in actual_memories: for mem in actual_memories:
memory_item_dict = MemoryItem( memory_item_dict = MemoryItem(
id=mem.id, id=mem.id,
memory=mem.payload["data"], memory=mem.payload["data"],
@@ -587,15 +585,13 @@ class Memory(MemoryBase):
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in mem.payload: if key in mem.payload:
memory_item_dict[key] = mem.payload[key] memory_item_dict[key] = mem.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
memory_item_dict["metadata"] = additional_metadata memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict) formatted_memories.append(memory_item_dict)
return formatted_memories return formatted_memories
def search( def search(
@@ -624,12 +620,9 @@ class Memory(MemoryBase):
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}` Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
""" """
_, effective_filters = _build_filters_and_metadata(
-    user_id=user_id,
-    agent_id=agent_id,
-    run_id=run_id,
-    input_filters=filters
+    user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.") raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
@@ -651,7 +644,7 @@ class Memory(MemoryBase):
original_memories = future_memories.result() original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph: if self.enable_graph:
return {"results": original_memories, "relations": graph_entities} return {"results": original_memories, "relations": graph_entities}
@@ -678,11 +671,8 @@ class Memory(MemoryBase):
"actor_id", "actor_id",
"role", "role",
] ]
core_and_promoted_keys = { core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
original_memories = [] original_memories = []
for mem in memories: for mem in memories:
@@ -693,18 +683,16 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"), created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"), updated_at=mem.payload.get("updated_at"),
score=mem.score, score=mem.score,
).model_dump() ).model_dump()
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in mem.payload: if key in mem.payload:
memory_item_dict[key] = mem.payload[key] memory_item_dict[key] = mem.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
memory_item_dict["metadata"] = additional_metadata memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict) original_memories.append(memory_item_dict)
return original_memories return original_memories
@@ -738,7 +726,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id) self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"} return {"message": "Memory deleted successfully!"}
-def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None):
+def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
""" """
Delete all memories. Delete all memories.
@@ -860,11 +848,11 @@ class Memory(MemoryBase):
except Exception: except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.") logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data") prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {} new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at") new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -875,7 +863,7 @@ class Memory(MemoryBase):
if "agent_id" in existing_memory.payload: if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"] new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload: if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"] new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload: if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"] new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload: if "role" in existing_memory.payload:
@@ -885,14 +873,14 @@ class Memory(MemoryBase):
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = self.embedding_model.embed(data, "update") embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update( self.vector_store.update(
vector_id=memory_id, vector_id=memory_id,
vector=embeddings, vector=embeddings,
payload=new_metadata, payload=new_metadata,
) )
logger.info(f"Updating memory with ID {memory_id=} with {data=}") logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history( self.db.add_history(
memory_id, memory_id,
prev_value, prev_value,
@@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase):
dict: A dictionary containing the result of the memory addition operation. dict: A dictionary containing the result of the memory addition operation.
""" """
processed_metadata, effective_filters = _build_filters_and_metadata(
-    user_id=user_id,
-    agent_id=agent_id,
-    run_id=run_id,
-    input_metadata=metadata
+    user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
)
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError( raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories." f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -1050,15 +1035,17 @@ class AsyncMemory(MemoryBase):
if isinstance(messages, str): if isinstance(messages, str):
messages = [{"role": "user", "content": messages}] messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict): elif isinstance(messages, dict):
messages = [messages] messages = [messages]
elif not isinstance(messages, list): elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]") raise ValueError("messages must be str, dict, or list[dict]")
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
-results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm)
+results = await self._create_procedural_memory(
+    messages, metadata=processed_metadata, prompt=prompt, llm=llm
+)
return results return results
if self.config.llm.config.get("enable_vision"): if self.config.llm.config.get("enable_vision"):
@@ -1066,7 +1053,9 @@ class AsyncMemory(MemoryBase):
else: else:
messages = parse_vision_messages(messages) messages = parse_vision_messages(messages)
-vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer))
+vector_store_task = asyncio.create_task(
+    self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)
+)
graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters)) graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters))
vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task) vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
@@ -1090,8 +1079,8 @@ class AsyncMemory(MemoryBase):
return {"results": vector_store_result} return {"results": vector_store_result}
async def _add_to_vector_store( async def _add_to_vector_store(
self, self,
messages: list, messages: list,
metadata: dict, metadata: dict,
filters: dict, filters: dict,
infer: bool, infer: bool,
@@ -1099,9 +1088,11 @@ class AsyncMemory(MemoryBase):
if not infer: if not infer:
returned_memories = [] returned_memories = []
for message_dict in messages: for message_dict in messages:
-    if not isinstance(message_dict, dict) or \
-        message_dict.get("role") is None or \
-        message_dict.get("content") is None:
+    if (
+        not isinstance(message_dict, dict)
+        or message_dict.get("role") is None
+        or message_dict.get("content") is None
+    ):
logger.warning(f"Skipping invalid message format (async): {message_dict}") logger.warning(f"Skipping invalid message format (async): {message_dict}")
continue continue
@@ -1110,20 +1101,24 @@ class AsyncMemory(MemoryBase):
per_msg_meta = deepcopy(metadata) per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"] per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name") actor_name = message_dict.get("name")
if actor_name: if actor_name:
per_msg_meta["actor_id"] = actor_name per_msg_meta["actor_id"] = actor_name
msg_content = message_dict["content"] msg_content = message_dict["content"]
msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add") msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta) mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta)
-returned_memories.append({
-    "id": mem_id, "memory": msg_content, "event": "ADD",
-    "actor_id": actor_name if actor_name else None,
-    "role": message_dict["role"],
-})
+returned_memories.append(
+    {
+        "id": mem_id,
+        "memory": msg_content,
+        "event": "ADD",
+        "actor_id": actor_name if actor_name else None,
+        "role": message_dict["role"],
+    }
+)
return returned_memories return returned_memories
parsed_messages = parse_messages(messages) parsed_messages = parse_messages(messages)
@@ -1142,17 +1137,21 @@ class AsyncMemory(MemoryBase):
response = remove_code_blocks(response) response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"] new_retrieved_facts = json.loads(response)["facts"]
except Exception as e: except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = [] logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
retrieved_old_memory = [] retrieved_old_memory = []
new_message_embeddings = {} new_message_embeddings = {}
async def process_fact_for_search(new_mem_content): async def process_fact_for_search(new_mem_content):
embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add") embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add")
new_message_embeddings[new_mem_content] = embeddings new_message_embeddings[new_mem_content] = embeddings
existing_mems = await asyncio.to_thread( existing_mems = await asyncio.to_thread(
-    self.vector_store.search, query=new_mem_content, vectors=embeddings,
-    limit=5, filters=filters,  # 'filters' is query_filters_for_inference
+    self.vector_store.search,
+    query=new_mem_content,
+    vectors=embeddings,
+    limit=5,
+    filters=filters,  # 'filters' is query_filters_for_inference
)
return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems] return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
@@ -1160,9 +1159,10 @@ class AsyncMemory(MemoryBase):
search_results_list = await asyncio.gather(*search_tasks) search_results_list = await asyncio.gather(*search_tasks)
for result_group in search_results_list: for result_group in search_results_list:
retrieved_old_memory.extend(result_group) retrieved_old_memory.extend(result_group)
unique_data = {} unique_data = {}
-for item in retrieved_old_memory: unique_data[item["id"]] = item
+for item in retrieved_old_memory:
+    unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values()) retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}") logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
temp_uuid_mapping = {} temp_uuid_mapping = {}
@@ -1180,35 +1180,45 @@ class AsyncMemory(MemoryBase):
response_format={"type": "json_object"}, response_format={"type": "json_object"},
) )
except Exception as e: except Exception as e:
logging.error(f"Error in new memory actions response: {e}"); response = "" logging.error(f"Error in new memory actions response: {e}")
response = ""
try: try:
response = remove_code_blocks(response) response = remove_code_blocks(response)
new_memories_with_actions = json.loads(response) new_memories_with_actions = json.loads(response)
except Exception as e: except Exception as e:
logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {} logging.error(f"Invalid JSON response: {e}")
new_memories_with_actions = {}
returned_memories = [] returned_memories = []
try: try:
memory_tasks = [] memory_tasks = []
for resp in new_memories_with_actions.get("memory", []): for resp in new_memories_with_actions.get("memory", []):
logging.info(resp) logging.info(resp)
try: try:
action_text = resp.get("text") action_text = resp.get("text")
-    if not action_text: continue
+    if not action_text:
+        continue
event_type = resp.get("event") event_type = resp.get("event")
if event_type == "ADD": if event_type == "ADD":
-    task = asyncio.create_task(self._create_memory(
-        data=action_text, existing_embeddings=new_message_embeddings,
-        metadata=deepcopy(metadata)
-    ))
+    task = asyncio.create_task(
+        self._create_memory(
+            data=action_text,
+            existing_embeddings=new_message_embeddings,
+            metadata=deepcopy(metadata),
+        )
+    )
    memory_tasks.append((task, resp, "ADD", None))
elif event_type == "UPDATE":
-    task = asyncio.create_task(self._update_memory(
-        memory_id=temp_uuid_mapping[resp["id"]], data=action_text,
-        existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata)
-    ))
+    task = asyncio.create_task(
+        self._update_memory(
+            memory_id=temp_uuid_mapping[resp["id"]],
+            data=action_text,
+            existing_embeddings=new_message_embeddings,
+            metadata=deepcopy(metadata),
+        )
+    )
    memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif event_type == "DELETE": elif event_type == "DELETE":
task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])) task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
@@ -1217,31 +1227,30 @@ class AsyncMemory(MemoryBase):
logging.info("NOOP for Memory (async).") logging.info("NOOP for Memory (async).")
except Exception as e: except Exception as e:
logging.error(f"Error processing memory action (async): {resp}, Error: {e}") logging.error(f"Error processing memory action (async): {resp}, Error: {e}")
for task, resp, event_type, mem_id in memory_tasks: for task, resp, event_type, mem_id in memory_tasks:
try: try:
result_id = await task result_id = await task
if event_type == "ADD": if event_type == "ADD":
-    returned_memories.append({
-        "id": result_id, "memory": resp.get("text"), "event": event_type
-    })
+    returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type})
elif event_type == "UPDATE":
-    returned_memories.append({
-        "id": mem_id, "memory": resp.get("text"),
-        "event": event_type, "previous_memory": resp.get("old_memory")
-    })
+    returned_memories.append(
+        {
+            "id": mem_id,
+            "memory": resp.get("text"),
+            "event": event_type,
+            "previous_memory": resp.get("old_memory"),
+        }
+    )
elif event_type == "DELETE":
-    returned_memories.append({
-        "id": mem_id, "memory": resp.get("text"), "event": event_type
-    })
+    returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type})
except Exception as e: except Exception as e:
logging.error(f"Error awaiting memory task (async): {e}") logging.error(f"Error awaiting memory task (async): {e}")
except Exception as e: except Exception as e:
logging.error(f"Error in memory processing loop (async): {e}") logging.error(f"Error in memory processing loop (async): {e}")
capture_event( capture_event(
"mem0.add", self, "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
{"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
) )
return returned_memories return returned_memories
@@ -1272,17 +1281,14 @@ class AsyncMemory(MemoryBase):
return None return None
promoted_payload_keys = [ promoted_payload_keys = [
"user_id", "user_id",
"agent_id", "agent_id",
"run_id", "run_id",
"actor_id", "actor_id",
"role", "role",
] ]
core_and_promoted_keys = { core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
result_item = MemoryItem( result_item = MemoryItem(
id=memory.id, id=memory.id,
@@ -1295,18 +1301,16 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in memory.payload: if key in memory.payload:
result_item[key] = memory.payload[key] result_item[key] = memory.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
result_item["metadata"] = additional_metadata result_item["metadata"] = additional_metadata
return result_item return result_item
async def get_all( async def get_all(
self, self,
*, *,
user_id: Optional[str] = None, user_id: Optional[str] = None,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
run_id: Optional[str] = None, run_id: Optional[str] = None,
@@ -1314,41 +1318,36 @@ class AsyncMemory(MemoryBase):
limit: int = 100, limit: int = 100,
): ):
""" """
List all memories. List all memories.
Args: Args:
user_id (str, optional): user id user_id (str, optional): user id
agent_id (str, optional): agent id agent_id (str, optional): agent id
run_id (str, optional): run id run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search. filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example, These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`. `filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100. limit (int, optional): The maximum number of memories to return. Defaults to 100.
Returns: Returns:
dict: A dictionary containing a list of memories under the "results" key, dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0, and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning). it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
""" """
_, effective_filters = _build_filters_and_metadata(
-    user_id=user_id,
-    agent_id=agent_id,
-    run_id=run_id,
-    input_filters=filters
+    user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
    raise ValueError(
        "When 'conversation_id' is not provided (classic mode), "
        "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
    )
capture_event(
-    "mem0.get_all",
-    self,
-    {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
+    "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1361,9 +1360,9 @@ class AsyncMemory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories] [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
) )
all_memories_result = future_memories.result() all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph: if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result} return {"results": all_memories_result, "relations": graph_entities_result}
@@ -1381,20 +1380,21 @@ class AsyncMemory(MemoryBase):
async def _get_all_from_vector_store(self, filters, limit): async def _get_all_from_vector_store(self, filters, limit):
memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit) memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
-actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
+actual_memories = (
+    memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
+)
promoted_payload_keys = [
-    "user_id", "agent_id", "run_id",
+    "user_id",
+    "agent_id",
+    "run_id",
    "actor_id",
    "role",
]
-core_and_promoted_keys = {
-    "data", "hash", "created_at", "updated_at", "id",
-    *promoted_payload_keys
-}
+core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = [] formatted_memories = []
for mem in actual_memories: for mem in actual_memories:
memory_item_dict = MemoryItem( memory_item_dict = MemoryItem(
id=mem.id, id=mem.id,
memory=mem.payload["data"], memory=mem.payload["data"],
@@ -1406,15 +1406,13 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in mem.payload: if key in mem.payload:
memory_item_dict[key] = mem.payload[key] memory_item_dict[key] = mem.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
memory_item_dict["metadata"] = additional_metadata memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict) formatted_memories.append(memory_item_dict)
return formatted_memories return formatted_memories
async def search( async def search(
@@ -1442,16 +1440,13 @@ class AsyncMemory(MemoryBase):
and potentially "relations" if graph store is enabled. and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}` Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
""" """
_, effective_filters = _build_filters_and_metadata(
-    user_id=user_id,
-    agent_id=agent_id,
-    run_id=run_id,
-    input_filters=filters
+    user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ") raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
capture_event( capture_event(
"mem0.search", "mem0.search",
@@ -1460,22 +1455,20 @@ class AsyncMemory(MemoryBase):
) )
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit)) vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
graph_task = None graph_task = None
if self.enable_graph: if self.enable_graph:
if hasattr(self.graph.search, "__await__"): # Check if graph search is async if hasattr(self.graph.search, "__await__"): # Check if graph search is async
graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit)) graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit))
else: else:
-    graph_task = asyncio.create_task(
-        asyncio.to_thread(self.graph.search, query, effective_filters, limit)
-    )
+    graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit))
if graph_task: if graph_task:
original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task) original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else: else:
original_memories = await vector_store_task original_memories = await vector_store_task
graph_entities = None graph_entities = None
if self.enable_graph: if self.enable_graph:
return {"results": original_memories, "relations": graph_entities} return {"results": original_memories, "relations": graph_entities}
@@ -1504,11 +1497,8 @@ class AsyncMemory(MemoryBase):
"actor_id", "actor_id",
"role", "role",
] ]
core_and_promoted_keys = { core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
original_memories = [] original_memories = []
for mem in memories: for mem in memories:
@@ -1518,19 +1508,17 @@ class AsyncMemory(MemoryBase):
hash=mem.payload.get("hash"), hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"), created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"), updated_at=mem.payload.get("updated_at"),
score=mem.score, score=mem.score,
).model_dump() ).model_dump()
for key in promoted_payload_keys: for key in promoted_payload_keys:
if key in mem.payload: if key in mem.payload:
memory_item_dict[key] = mem.payload[key] memory_item_dict[key] = mem.payload[key]
additional_metadata = { additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
if additional_metadata: if additional_metadata:
memory_item_dict["metadata"] = additional_metadata memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict) original_memories.append(memory_item_dict)
return original_memories return original_memories
@@ -1650,7 +1638,7 @@ class AsyncMemory(MemoryBase):
capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"}) capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"})
return memory_id return memory_id
async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None): async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
""" """
Create a procedural memory asynchronously Create a procedural memory asynchronously
@@ -1709,11 +1697,11 @@ class AsyncMemory(MemoryBase):
except Exception: except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.") logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data") prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {} new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at") new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -1725,8 +1713,7 @@ class AsyncMemory(MemoryBase):
new_metadata["agent_id"] = existing_memory.payload["agent_id"] new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload: if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"] new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload: if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"] new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload: if "role" in existing_memory.payload:
@@ -1736,7 +1723,7 @@ class AsyncMemory(MemoryBase):
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update") embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
await asyncio.to_thread( await asyncio.to_thread(
self.vector_store.update, self.vector_store.update,
vector_id=memory_id, vector_id=memory_id,
@@ -1744,7 +1731,7 @@ class AsyncMemory(MemoryBase):
payload=new_metadata, payload=new_metadata,
) )
logger.info(f"Updating memory with ID {memory_id=} with {data=}") logger.info(f"Updating memory with ID {memory_id=} with {data=}")
await asyncio.to_thread( await asyncio.to_thread(
self.db.add_history, self.db.add_history,
memory_id, memory_id,
View File
@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try: try:
from langchain_memgraph import Memgraph from langchain_memgraph import Memgraph
except ImportError: except ImportError:
raise ImportError( raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
)
try: try:
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
except ImportError: except ImportError:
raise ImportError( raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
)
from mem0.graphs.tools import ( from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH, DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
filters (dict): A dictionary containing filters to be applied during the addition. filters (dict): A dictionary containing filters to be applied during the addition.
""" """
entity_type_map = self._retrieve_nodes_from_data(data, filters) entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data( to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
data, filters, entity_type_map search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
) to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
to_be_deleted = self._get_delete_entities_from_search_output(
search_output, data, filters
)
# TODO: Batch queries with APOC plugin # TODO: Batch queries with APOC plugin
# TODO: Add more filter support # TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"]) deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities( added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
to_be_added, filters["user_id"], entity_type_map
)
return {"deleted_entities": deleted_entities, "added_entities": added_entities} return {"deleted_entities": deleted_entities, "added_entities": added_entities}
@@ -108,16 +96,13 @@ class MemoryGraph:
- "entities": List of related graph data based on the query. - "entities": List of related graph data based on the query.
""" """
entity_type_map = self._retrieve_nodes_from_data(query, filters) entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db( search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
node_list=list(entity_type_map.keys()), filters=filters
)
if not search_output: if not search_output:
return [] return []
search_outputs_sequence = [ search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]] [item["source"], item["relationship"], item["destination"]] for item in search_output
for item in search_output
] ]
bm25 = BM25Okapi(search_outputs_sequence) bm25 = BM25Okapi(search_outputs_sequence)
@@ -126,9 +111,7 @@ class MemoryGraph:
search_results = [] search_results = []
for item in reranked_results: for item in reranked_results:
search_results.append( search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
{"source": item[0], "relationship": item[1], "destination": item[2]}
)
logger.info(f"Returned {len(search_results)} search results") logger.info(f"Returned {len(search_results)} search results")
@@ -161,9 +144,7 @@ class MemoryGraph:
RETURN n.name AS source, type(r) AS relationship, m.name AS target RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit LIMIT $limit
""" """
results = self.graph.query( results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
query, params={"user_id": filters["user_id"], "limit": limit}
)
final_results = [] final_results = []
for result in results: for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}" f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
) )
entity_type_map = { entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
k.lower().replace(" ", "_"): v.lower().replace(" ", "_") logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
for k, v in entity_type_map.items()
}
logger.debug(
f"Entity type map: {entity_type_map}\n search_results={search_results}"
)
return entity_type_map return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map): def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace( "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"USER_ID", filters["user_id"]
).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
), ),
}, },
@@ -235,9 +209,7 @@ class MemoryGraph:
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace( "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
"USER_ID", filters["user_id"]
),
}, },
{ {
"role": "user", "role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
def _get_delete_entities_from_search_output(self, search_output, data, filters): def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output.""" """Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output) search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages( system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
search_output_string, data, filters["user_id"]
)
_tools = [DELETE_MEMORY_TOOL_GRAPH] _tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]: if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
# search for the nodes with the closest embeddings; this is basically # search for the nodes with the closest embeddings; this is basically
# comparison of one embedding to all embeddings in a graph -> vector # comparison of one embedding to all embeddings in a graph -> vector
# search with cosine similarity metric # search with cosine similarity metric
source_node_search_result = self._search_source_node( source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
source_embedding, user_id, threshold=0.9 destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
)
destination_node_search_result = self._search_destination_node(
dest_embedding, user_id, threshold=0.9
)
# TODO: Create a cypher query and common params for all the cases # TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result: if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
""" """
params = { params = {
"destination_id": destination_node_search_result[0][ "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"id(destination_candidate)"
],
"source_name": source, "source_name": source,
"source_embedding": source_embedding, "source_embedding": source_embedding,
"user_id": user_id, "user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
""" """
params = { params = {
"source_id": source_node_search_result[0]["id(source_candidate)"], "source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0][ "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"id(destination_candidate)"
],
"user_id": user_id, "user_id": user_id,
} }
else: else:
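The graph search above re-ranks candidate triplets with BM25, treating each [source, relationship, destination] triple as a tiny document. A self-contained sketch of that step with illustrative triplets and query:

from rank_bm25 import BM25Okapi

triplets = [
    ["alice", "lives_in", "berlin"],
    ["alice", "works_at", "acme"],
    ["bob", "lives_in", "paris"],
]
bm25 = BM25Okapi(triplets)  # each triple acts as a tokenized document
query_tokens = "where does alice live".split()
reranked = bm25.get_top_n(query_tokens, triplets, n=2)
results = [{"source": s, "relationship": r, "destination": d} for s, r, d in reranked]
print(results)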
View File
@@ -1,8 +1,8 @@
import logging
import sqlite3 import sqlite3
import threading import threading
import uuid import uuid
import logging from typing import Any, Dict, List, Optional
from typing import List, Dict, Any, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -23,9 +23,7 @@ class SQLiteManager:
""" """
with self._lock, self.connection: with self._lock, self.connection:
cur = self.connection.cursor() cur = self.connection.cursor()
cur.execute( cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
"SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
)
if cur.fetchone() is None: if cur.fetchone() is None:
return # nothing to migrate return # nothing to migrate
@@ -51,13 +49,11 @@ class SQLiteManager:
logger.info("Migrating history table to new schema (no convo columns).") logger.info("Migrating history table to new schema (no convo columns).")
cur.execute("ALTER TABLE history RENAME TO history_old") cur.execute("ALTER TABLE history RENAME TO history_old")
self._create_history_table() self._create_history_table()
intersecting = list(expected_cols & old_cols) intersecting = list(expected_cols & old_cols)
cols_csv = ", ".join(intersecting) cols_csv = ", ".join(intersecting)
cur.execute( cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old"
)
cur.execute("DROP TABLE history_old") cur.execute("DROP TABLE history_old")
def _create_history_table(self) -> None: def _create_history_table(self) -> None:
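The migration above uses the usual SQLite copy-and-swap pattern: rename the old table, recreate the target schema, copy the intersecting columns, then drop the old table. A compact sketch with simplified, illustrative column names:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE history (id TEXT, memory_id TEXT, old_memory TEXT, extra TEXT)")
conn.execute("INSERT INTO history VALUES ('1', 'm1', 'hello', 'x')")

expected_cols = {"id", "memory_id", "old_memory", "new_memory"}
old_cols = {row[1] for row in conn.execute("PRAGMA table_info(history)")}

with conn:  # single transaction
    conn.execute("ALTER TABLE history RENAME TO history_old")
    conn.execute("CREATE TABLE history (id TEXT, memory_id TEXT, old_memory TEXT, new_memory TEXT)")
    cols_csv = ", ".join(sorted(expected_cols & old_cols))  # only columns present in both schemas
    conn.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
    conn.execute("DROP TABLE history_old")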
View File
@@ -9,8 +9,8 @@ import mem0
from mem0.memory.setup import get_or_create_user_id from mem0.memory.setup import get_or_create_user_id
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True") MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX" PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST="https://us.i.posthog.com" HOST = "https://us.i.posthog.com"
if isinstance(MEM0_TELEMETRY, str): if isinstance(MEM0_TELEMETRY, str):
MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes") MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
View File
@@ -98,9 +98,8 @@ class VectorStoreFactory:
return vector_store_instance(**config) return vector_store_instance(**config)
else: else:
raise ValueError(f"Unsupported VectorStore provider: {provider_name}") raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
@classmethod @classmethod
def reset(cls, instance): def reset(cls, instance):
instance.reset() instance.reset()
return instance return instance
View File
@@ -377,4 +377,3 @@ class AzureAISearch(VectorStoreBase):
except Exception as e: except Exception as e:
logger.error(f"Error resetting index {self.index_name}: {e}") logger.error(f"Error resetting index {self.index_name}: {e}")
raise raise
View File
@@ -51,7 +51,7 @@ class VectorStoreBase(ABC):
def list(self, filters=None, limit=None): def list(self, filters=None, limit=None):
"""List all memories.""" """List all memories."""
pass pass
@abstractmethod @abstractmethod
def reset(self): def reset(self):
"""Reset by delete the collection and recreate it.""" """Reset by delete the collection and recreate it."""
View File
@@ -221,7 +221,7 @@ class ChromaDB(VectorStoreBase):
""" """
results = self.collection.get(where=filters, limit=limit) results = self.collection.get(where=filters, limit=limit)
return [self._parse_output(results)] return [self._parse_output(results)]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -58,7 +58,12 @@ class ElasticsearchDB(VectorStoreBase):
"mappings": { "mappings": {
"properties": { "properties": {
"text": {"type": "text"}, "text": {"type": "text"},
"vector": {"type": "dense_vector", "dims": self.embedding_model_dims, "index": True, "similarity": "cosine"}, "vector": {
"type": "dense_vector",
"dims": self.embedding_model_dims,
"index": True,
"similarity": "cosine",
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}}, "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
} }
}, },
@@ -222,7 +227,7 @@ class ElasticsearchDB(VectorStoreBase):
) )
return [results] return [results]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -465,7 +465,7 @@ class FAISS(VectorStoreBase):
break break
return [results] return [results]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OutputData(BaseModel): class OutputData(BaseModel):
id: Optional[str] # memory id id: Optional[str] # memory id
score: Optional[float] # distance score: Optional[float] # distance
@@ -162,10 +163,7 @@ class Langchain(VectorStoreBase):
if filters and "user_id" in filters: if filters and "user_id" in filters:
where_clause = {"user_id": filters["user_id"]} where_clause = {"user_id": filters["user_id"]}
result = self.client._collection.get( result = self.client._collection.get(where=where_clause, limit=limit)
where=where_clause,
limit=limit
)
# Convert the result to the expected format # Convert the result to the expected format
if result and isinstance(result, dict): if result and isinstance(result, dict):
View File
@@ -237,7 +237,7 @@ class MilvusDB(VectorStoreBase):
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata")) obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj) memories.append(obj)
return [memories] return [memories]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -1,6 +1,6 @@
import logging import logging
from typing import Any, Dict, List, Optional
import time import time
from typing import Any, Dict, List, Optional
try: try:
from opensearchpy import OpenSearch, RequestsHttpConnection from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl, use_ssl=config.use_ssl,
verify_certs=config.verify_certs, verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection, connection_class=RequestsHttpConnection,
pool_maxsize=20 pool_maxsize=20,
) )
self.collection_name = config.collection_name self.collection_name = config.collection_name
@@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None: def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch).""" """Create a new collection (index in OpenSearch)."""
index_settings = { index_settings = {
"settings": { "settings": {"index.knn": True},
"index.knn": True
},
"mappings": { "mappings": {
"properties": { "properties": {
"vector_field": { "vector_field": {
@@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase):
"payload": {"type": "object"}, "payload": {"type": "object"},
"id": {"type": "keyword"}, "id": {"type": "keyword"},
} }
} },
} }
if not self.client.indices.exists(index=name): if not self.client.indices.exists(index=name):
@@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase):
except Exception: except Exception:
retry_count += 1 retry_count += 1
if retry_count == max_retries: if retry_count == max_retries:
raise TimeoutError( raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
f"Index {name} creation timed out after {max_retries} seconds"
)
time.sleep(0.5) time.sleep(0.5)
def insert( def insert(
@@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase):
} }
# Start building the full query # Start building the full query
query_body = { query_body = {"size": limit * 2, "query": None}
"size": limit * 2,
"query": None
}
# Prepare filter conditions if applicable # Prepare filter conditions if applicable
filter_clauses = [] filter_clauses = []
@@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase):
for key in ["user_id", "run_id", "agent_id"]: for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key) value = filters.get(key)
if value: if value:
filter_clauses.append({ filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
"term": {f"payload.{key}.keyword": value}
})
# Combine knn with filters if needed # Combine knn with filters if needed
if filter_clauses: if filter_clauses:
query_body["query"] = { query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
"bool": {
"must": knn_query,
"filter": filter_clauses
}
}
else: else:
query_body["query"] = knn_query query_body["query"] = knn_query
@@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase):
hits = response["hits"]["hits"] hits = response["hits"]["hits"]
results = [ results = [
OutputData( OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
id=hit["_source"].get("id"),
score=hit["_score"],
payload=hit["_source"].get("payload", {})
)
for hit in hits for hit in hits
] ]
return results return results
@@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase):
def delete(self, vector_id: str) -> None: def delete(self, vector_id: str) -> None:
"""Delete a vector by custom ID.""" """Delete a vector by custom ID."""
# First, find the document by custom ID # First, find the document by custom ID
search_query = { search_query = {"query": {"term": {"id": vector_id}}}
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query) response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", []) hits = response.get("hits", {}).get("hits", [])
@@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase):
# Delete using the actual document ID # Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id) self.client.delete(index=self.collection_name, id=opensearch_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None: def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload using the custom 'id' field.""" """Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID # First, find the document by custom ID
search_query = { search_query = {"query": {"term": {"id": vector_id}}}
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query) response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", []) hits = response.get("hits", {}).get("hits", [])
@@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase):
except Exception: except Exception:
pass pass
def get(self, vector_id: str) -> Optional[OutputData]: def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID.""" """Retrieve a vector by ID."""
try: try:
@@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase):
self.create_col(self.collection_name, self.embedding_model_dims) self.create_col(self.collection_name, self.embedding_model_dims)
return None return None
search_query = { search_query = {"query": {"term": {"id": vector_id}}}
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query) response = self.client.search(index=self.collection_name, body=search_query)
hits = response["hits"]["hits"] hits = response["hits"]["hits"]
@@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase):
if not hits: if not hits:
return None return None
return OutputData( return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
id=hits[0]["_source"].get("id"),
score=1.0,
payload=hits[0]["_source"].get("payload", {})
)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {str(e)}") logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None return None
@@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase):
return self.client.indices.get(index=name) return self.client.indices.get(index=name)
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]: def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
try: try:
"""List all memories with optional filters.""" """List all memories with optional filters."""
query: Dict = { query: Dict = {"query": {"match_all": {}}}
"query": {
"match_all": {}
}
}
filter_clauses = [] filter_clauses = []
if filters: if filters:
for key in ["user_id", "run_id", "agent_id"]: for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key) value = filters.get(key)
if value: if value:
filter_clauses.append({ filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
"term": {f"payload.{key}.keyword": value}
})
if filter_clauses: if filter_clauses:
query["query"] = { query["query"] = {"bool": {"filter": filter_clauses}}
"bool": {
"filter": filter_clauses
}
}
if limit: if limit:
query["size"] = limit query["size"] = limit
@@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase):
response = self.client.search(index=self.collection_name, body=query) response = self.client.search(index=self.collection_name, body=query)
hits = response["hits"]["hits"] hits = response["hits"]["hits"]
return [[ return [
OutputData( [
id=hit["_source"].get("id"), OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
score=1.0, for hit in hits
payload=hit["_source"].get("payload", {}) ]
) ]
for hit in hits
]]
except Exception: except Exception:
return [] return []
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -286,7 +286,7 @@ class PGVector(VectorStoreBase):
self.cur.close() self.cur.close()
if hasattr(self, "conn"): if hasattr(self, "conn"):
self.conn.close() self.conn.close()
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -232,7 +232,7 @@ class Qdrant(VectorStoreBase):
with_vectors=False, with_vectors=False,
) )
return result return result
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -88,7 +88,7 @@ class RedisDB(VectorStoreBase):
The created index object. The created index object.
""" """
# Use provided parameters or fall back to instance attributes # Use provided parameters or fall back to instance attributes
collection_name = name or self.schema['index']['name'] collection_name = name or self.schema["index"]["name"]
embedding_dims = vector_size or self.embedding_model_dims embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "cosine" distance_metric = distance or "cosine"
@@ -237,17 +237,16 @@ class RedisDB(VectorStoreBase):
""" """
Reset the index by deleting and recreating it. Reset the index by deleting and recreating it.
""" """
collection_name = self.schema['index']['name'] collection_name = self.schema["index"]["name"]
logger.warning(f"Resetting index {collection_name}...") logger.warning(f"Resetting index {collection_name}...")
self.delete_col() self.delete_col()
self.index = SearchIndex.from_dict(self.schema) self.index = SearchIndex.from_dict(self.schema)
self.index.set_client(self.client) self.index.set_client(self.client)
self.index.create(overwrite=True) self.index.create(overwrite=True)
#or use
#self.create_col(collection_name, self.embedding_model_dims)
# or use
# self.create_col(collection_name, self.embedding_model_dims)
# Recreate the index with the same parameters # Recreate the index with the same parameters
self.create_col(collection_name, self.embedding_model_dims) self.create_col(collection_name, self.embedding_model_dims)
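SearchIndex.from_dict consumes a redisvl-style schema dict, and the code above relies only on schema["index"]["name"]. An illustrative shape (the field layout is an assumption, not taken from the repository):

schema = {
    "index": {"name": "mem0", "prefix": "mem0:"},
    "fields": [
        {"name": "memory_id", "type": "tag"},
        {
            "name": "embedding",
            "type": "vector",
            "attrs": {"dims": 1536, "distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
        },
    ],
}
collection_name = schema["index"]["name"]  # -> "mem0"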
View File
@@ -229,7 +229,7 @@ class Supabase(VectorStoreBase):
records = self.collection.fetch(ids=ids) records = self.collection.fetch(ids=ids)
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]] return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -285,10 +285,9 @@ class UpstashVector(VectorStoreBase):
- Per-namespace vector and pending vector counts - Per-namespace vector and pending vector counts
""" """
return self.client.info() return self.client.info()
def reset(self): def reset(self):
""" """
Reset the Upstash Vector index. Reset the Upstash Vector index.
""" """
self.delete_col() self.delete_col()
View File
@@ -308,7 +308,7 @@ class Weaviate(VectorStoreBase):
payload["id"] = str(obj.uuid).split("'")[0] payload["id"] = str(obj.uuid).split("'")[0]
results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload)) results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
return [results] return [results]
def reset(self): def reset(self):
"""Reset the index by deleting and recreating it.""" """Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...") logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -44,31 +44,14 @@ DEFAULT_CONFIG = {
"user": POSTGRES_USER, "user": POSTGRES_USER,
"password": POSTGRES_PASSWORD, "password": POSTGRES_PASSWORD,
"collection_name": POSTGRES_COLLECTION_NAME, "collection_name": POSTGRES_COLLECTION_NAME,
} },
}, },
"graph_store": { "graph_store": {
"provider": "neo4j", "provider": "neo4j",
"config": { "config": {"url": NEO4J_URI, "username": NEO4J_USERNAME, "password": NEO4J_PASSWORD},
"url": NEO4J_URI,
"username": NEO4J_USERNAME,
"password": NEO4J_PASSWORD
}
},
"llm": {
"provider": "openai",
"config": {
"api_key": OPENAI_API_KEY,
"temperature": 0.2,
"model": "gpt-4o"
}
},
"embedder": {
"provider": "openai",
"config": {
"api_key": OPENAI_API_KEY,
"model": "text-embedding-3-small"
}
}, },
"llm": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "temperature": 0.2, "model": "gpt-4o"}},
"embedder": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "model": "text-embedding-3-small"}},
"history_db_path": HISTORY_DB_PATH, "history_db_path": HISTORY_DB_PATH,
} }
@@ -115,9 +98,7 @@ def set_config(config: Dict[str, Any]):
def add_memory(memory_create: MemoryCreate): def add_memory(memory_create: MemoryCreate):
"""Store new memories.""" """Store new memories."""
if not any([memory_create.user_id, memory_create.agent_id, memory_create.run_id]): if not any([memory_create.user_id, memory_create.agent_id, memory_create.run_id]):
raise HTTPException( raise HTTPException(status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required.")
status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required."
)
params = {k: v for k, v in memory_create.model_dump().items() if v is not None and k != "messages"} params = {k: v for k, v in memory_create.model_dump().items() if v is not None and k != "messages"}
try: try:
@@ -138,7 +119,9 @@ def get_all_memories(
if not any([user_id, run_id, agent_id]): if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.") raise HTTPException(status_code=400, detail="At least one identifier is required.")
try: try:
params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None} params = {
k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
}
return MEMORY_INSTANCE.get_all(**params) return MEMORY_INSTANCE.get_all(**params)
except Exception as e: except Exception as e:
logging.exception("Error in get_all_memories:") logging.exception("Error in get_all_memories:")
@@ -207,7 +190,9 @@ def delete_all_memories(
if not any([user_id, run_id, agent_id]): if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.") raise HTTPException(status_code=400, detail="At least one identifier is required.")
try: try:
params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None} params = {
k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
}
MEMORY_INSTANCE.delete_all(**params) MEMORY_INSTANCE.delete_all(**params)
return {"message": "All relevant memories deleted"} return {"message": "All relevant memories deleted"}
except Exception as e: except Exception as e:
@@ -229,4 +214,4 @@ def reset_memory():
@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False) @app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False)
def home(): def home():
"""Redirect to the OpenAPI documentation.""" """Redirect to the OpenAPI documentation."""
return RedirectResponse(url='/docs') return RedirectResponse(url="/docs")
View File
@@ -5,13 +5,15 @@ def test_get_update_memory_messages():
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}] retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
response_content = ["new fact"] response_content = ["new fact"]
custom_update_memory_prompt = "custom prompt determining memory update" custom_update_memory_prompt = "custom prompt determining memory update"
## When custom update memory prompt is provided ## When custom update memory prompt is provided
## ##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt) result = prompts.get_update_memory_messages(
retrieved_old_memory_dict, response_content, custom_update_memory_prompt
)
assert result.startswith(custom_update_memory_prompt) assert result.startswith(custom_update_memory_prompt)
## When custom update memory prompt is not provided ## When custom update memory prompt is not provided
## ##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None) result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT) assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)
View File
@@ -10,9 +10,7 @@ from mem0.embeddings.lmstudio import LMStudioEmbedding
def mock_lm_studio_client(): def mock_lm_studio_client():
with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai: with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai:
mock_client = Mock() mock_client = Mock()
mock_client.embeddings.create.return_value = Mock( mock_client.embeddings.create.return_value = Mock(data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])])
data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])]
)
mock_openai.return_value = mock_client mock_openai.return_value = mock_client
yield mock_client yield mock_client
View File
@@ -23,7 +23,9 @@ def test_embed_default_model(mock_openai_client):
result = embedder.embed("Hello world") result = embedder.embed("Hello world")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536) mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.1, 0.2, 0.3] assert result == [0.1, 0.2, 0.3]
@@ -37,7 +39,7 @@ def test_embed_custom_model(mock_openai_client):
result = embedder.embed("Test embedding") result = embedder.embed("Test embedding")
mock_openai_client.embeddings.create.assert_called_once_with( mock_openai_client.embeddings.create.assert_called_once_with(
input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024 input=["Test embedding"], model="text-embedding-2-medium", dimensions=1024
) )
assert result == [0.4, 0.5, 0.6] assert result == [0.4, 0.5, 0.6]
@@ -51,7 +53,9 @@ def test_embed_removes_newlines(mock_openai_client):
result = embedder.embed("Hello\nworld") result = embedder.embed("Hello\nworld")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536) mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.7, 0.8, 0.9] assert result == [0.7, 0.8, 0.9]
@@ -65,7 +69,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
result = embedder.embed("Testing API key") result = embedder.embed("Testing API key")
mock_openai_client.embeddings.create.assert_called_once_with( mock_openai_client.embeddings.create.assert_called_once_with(
input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536 input=["Testing API key"], model="text-embedding-3-small", dimensions=1536
) )
assert result == [1.0, 1.1, 1.2] assert result == [1.0, 1.1, 1.2]
@@ -81,6 +85,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
result = embedder.embed("Environment key test") result = embedder.embed("Environment key test")
mock_openai_client.embeddings.create.assert_called_once_with( mock_openai_client.embeddings.create.assert_called_once_with(
input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536 input=["Environment key test"], model="text-embedding-3-small", dimensions=1536
) )
assert result == [1.3, 1.4, 1.5] assert result == [1.3, 1.4, 1.5]
View File
@@ -24,11 +24,20 @@ def mock_config():
with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config: with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config:
mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json" mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json"
yield mock_config yield mock_config
@pytest.fixture @pytest.fixture
def mock_embedding_types(): def mock_embedding_types():
return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"] return [
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"RETRIEVAL_DOCUMENT",
"RETRIEVAL_QUERY",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
"CODE_RETRIEVAL_QUERY",
]
@pytest.fixture @pytest.fixture
@@ -79,30 +88,31 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
assert result == [0.4, 0.5, 0.6] assert result == [0.4, 0.5, 0.6]
@patch("mem0.embeddings.vertexai.TextEmbeddingModel") @patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input): def test_embed_with_memory_action(
mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input
):
mock_config.return_value.model = "text-embedding-004" mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256 mock_config.return_value.embedding_dims = 256
for embedding_type in mock_embedding_types: for embedding_type in mock_embedding_types:
mock_config.return_value.memory_add_embedding_type = embedding_type mock_config.return_value.memory_add_embedding_type = embedding_type
mock_config.return_value.memory_update_embedding_type = embedding_type mock_config.return_value.memory_update_embedding_type = embedding_type
mock_config.return_value.memory_search_embedding_type = embedding_type mock_config.return_value.memory_search_embedding_type = embedding_type
config = mock_config() config = mock_config()
embedder = VertexAIEmbedding(config) embedder = VertexAIEmbedding(config)
mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004") mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004")
for memory_action in ["add", "update", "search"]: for memory_action in ["add", "update", "search"]:
embedder.embed("Hello world", memory_action=memory_action) embedder.embed("Hello world", memory_action=memory_action)
mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type) mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with( mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with(
texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256 texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256
) )
@patch("mem0.embeddings.vertexai.os") @patch("mem0.embeddings.vertexai.os")
def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config): def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config):
@@ -137,15 +147,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi
result = embedder.embed("Large embedding test") result = embedder.embed("Large embedding test")
assert result == [0.1] * 1024 assert result == [0.1] * 1024
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_invalid_memory_action(mock_text_embedding_model, mock_config): def test_invalid_memory_action(mock_text_embedding_model, mock_config):
mock_config.return_value.model = "text-embedding-004" mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256 mock_config.return_value.embedding_dims = 256
config = mock_config() config = mock_config()
embedder = VertexAIEmbedding(config) embedder = VertexAIEmbedding(config)
with pytest.raises(ValueError): with pytest.raises(ValueError):
embedder.embed("Hello world", memory_action="invalid_action") embedder.embed("Hello world", memory_action="invalid_action")
View File
@@ -127,4 +127,4 @@ def test_generate_with_http_proxies(default_headers):
api_version=None, api_version=None,
default_headers=default_headers, default_headers=default_headers,
) )
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
View File
@@ -31,12 +31,12 @@ def test_deepseek_llm_base_url():
# case3: with config.deepseek_base_url # case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/" config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig( config = BaseLlmConfig(
model="deepseek-chat", model="deepseek-chat",
temperature=0.7, temperature=0.7,
max_tokens=100, max_tokens=100,
top_p=1.0, top_p=1.0,
api_key="api_key", api_key="api_key",
deepseek_base_url=config_base_url deepseek_base_url=config_base_url,
) )
llm = DeepSeekLLM(config) llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url assert str(llm.client.base_url) == config_base_url
@@ -99,16 +99,16 @@ def test_generate_response_with_tools(mock_deepseek_client):
response = llm.generate_response(messages, tools=tools) response = llm.generate_response(messages, tools=tools)
mock_deepseek_client.chat.completions.create.assert_called_once_with( mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat", model="deepseek-chat",
messages=messages, messages=messages,
temperature=0.7, temperature=0.7,
max_tokens=100, max_tokens=100,
top_p=1.0, top_p=1.0,
tools=tools, tools=tools,
tool_choice="auto" tool_choice="auto",
) )
assert response["content"] == "I've added the memory for you." assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1 assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory" assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -10,6 +10,7 @@ try:
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
except ImportError: except ImportError:
from unittest.mock import MagicMock from unittest.mock import MagicMock
BaseChatModel = MagicMock BaseChatModel = MagicMock
@@ -24,16 +25,11 @@ def mock_langchain_model():
def test_langchain_initialization(mock_langchain_model): def test_langchain_initialization(mock_langchain_model):
"""Test that LangchainLLM initializes correctly with a valid model.""" """Test that LangchainLLM initializes correctly with a valid model."""
# Create a config with the model instance directly # Create a config with the model instance directly
config = BaseLlmConfig( config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
# Initialize the LangchainLLM # Initialize the LangchainLLM
llm = LangchainLLM(config) llm = LangchainLLM(config)
# Verify the model was correctly assigned # Verify the model was correctly assigned
assert llm.langchain_model == mock_langchain_model assert llm.langchain_model == mock_langchain_model
@@ -41,35 +37,30 @@ def test_langchain_initialization(mock_langchain_model):
def test_generate_response(mock_langchain_model): def test_generate_response(mock_langchain_model):
"""Test that generate_response correctly processes messages and returns a response.""" """Test that generate_response correctly processes messages and returns a response."""
# Create a config with the model instance # Create a config with the model instance
config = BaseLlmConfig( config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
# Initialize the LangchainLLM # Initialize the LangchainLLM
llm = LangchainLLM(config) llm = LangchainLLM(config)
# Create test messages # Create test messages
messages = [ messages = [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}, {"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"}, {"role": "assistant", "content": "I'm doing well! How can I help you?"},
{"role": "user", "content": "Tell me a joke."} {"role": "user", "content": "Tell me a joke."},
] ]
# Get response # Get response
response = llm.generate_response(messages) response = llm.generate_response(messages)
# Verify the correct message format was passed to the model # Verify the correct message format was passed to the model
expected_langchain_messages = [ expected_langchain_messages = [
("system", "You are a helpful assistant."), ("system", "You are a helpful assistant."),
("human", "Hello, how are you?"), ("human", "Hello, how are you?"),
("ai", "I'm doing well! How can I help you?"), ("ai", "I'm doing well! How can I help you?"),
("human", "Tell me a joke.") ("human", "Tell me a joke."),
] ]
mock_langchain_model.invoke.assert_called_once() mock_langchain_model.invoke.assert_called_once()
# Extract the first argument of the first call # Extract the first argument of the first call
actual_messages = mock_langchain_model.invoke.call_args[0][0] actual_messages = mock_langchain_model.invoke.call_args[0][0]
@@ -79,25 +70,15 @@ def test_generate_response(mock_langchain_model):
def test_invalid_model(): def test_invalid_model():
"""Test that LangchainLLM raises an error with an invalid model.""" """Test that LangchainLLM raises an error with an invalid model."""
config = BaseLlmConfig( config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key")
model="not-a-valid-model-instance",
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"): with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
LangchainLLM(config) LangchainLLM(config)
def test_missing_model(): def test_missing_model():
"""Test that LangchainLLM raises an error when model is None.""" """Test that LangchainLLM raises an error when model is None."""
config = BaseLlmConfig( config = BaseLlmConfig(model=None, temperature=0.7, max_tokens=100, api_key="test-api-key")
model=None,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
with pytest.raises(ValueError, match="`model` parameter is required"): with pytest.raises(ValueError, match="`model` parameter is required"):
LangchainLLM(config) LangchainLLM(config)
View File
@@ -11,9 +11,7 @@ def mock_lm_studio_client():
with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path
mock_client = Mock() mock_client = Mock()
mock_client.chat.completions.create.return_value = Mock( mock_client.chat.completions.create.return_value = Mock(
choices=[ choices=[Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
) )
mock_openai.return_value = mock_client mock_openai.return_value = mock_client
yield mock_client yield mock_client
View File
@@ -10,18 +10,19 @@ def _setup_mocks(mocker):
"""Helper to setup common mocks for both sync and async fixtures""" """Helper to setup common mocks for both sync and async fixtures"""
mock_embedder = mocker.MagicMock() mock_embedder = mocker.MagicMock()
mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3] mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder) mocker.patch("mem0.utils.factory.EmbedderFactory.create", mock_embedder)
mock_vector_store = mocker.MagicMock() mock_vector_store = mocker.MagicMock()
mock_vector_store.return_value.search.return_value = [] mock_vector_store.return_value.search.return_value = []
mocker.patch('mem0.utils.factory.VectorStoreFactory.create', mocker.patch(
side_effect=[mock_vector_store.return_value, mocker.MagicMock()]) "mem0.utils.factory.VectorStoreFactory.create", side_effect=[mock_vector_store.return_value, mocker.MagicMock()]
)
mock_llm = mocker.MagicMock() mock_llm = mocker.MagicMock()
mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm) mocker.patch("mem0.utils.factory.LlmFactory.create", mock_llm)
mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock()) mocker.patch("mem0.memory.storage.SQLiteManager", mocker.MagicMock())
return mock_llm, mock_vector_store return mock_llm, mock_vector_store
@@ -30,29 +31,26 @@ class TestAddToVectorStoreErrors:
def mock_memory(self, mocker): def mock_memory(self, mocker):
"""Fixture that returns a Memory instance with mocker-based mocks""" """Fixture that returns a Memory instance with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker) mock_llm, _ = _setup_mocks(mocker)
memory = Memory() memory = Memory()
memory.config = mocker.MagicMock() memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1" memory.api_version = "v1.1"
return memory return memory
def test_empty_llm_response_fact_extraction(self, mock_memory, caplog): def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
"""Test empty response from LLM during fact extraction""" """Test empty response from LLM during fact extraction"""
# Setup # Setup
mock_memory.llm.generate_response.return_value = "" mock_memory.llm.generate_response.return_value = ""
# Execute # Execute
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store( result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}], messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
metadata={},
filters={},
infer=True
) )
# Verify # Verify
assert mock_memory.llm.generate_response.call_count == 2 assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed assert result == [] # Should return empty list when no memories processed
@@ -62,20 +60,14 @@ class TestAddToVectorStoreErrors:
"""Test empty response from LLM during memory actions""" """Test empty response from LLM during memory actions"""
# Setup # Setup
# First call returns valid JSON, second call returns empty string # First call returns valid JSON, second call returns empty string
mock_memory.llm.generate_response.side_effect = [ mock_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
'{"facts": ["test fact"]}',
""
]
# Execute # Execute
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store( result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}], messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
metadata={},
filters={},
infer=True
) )
# Verify # Verify
assert mock_memory.llm.generate_response.call_count == 2 assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed assert result == [] # Should return empty list when no memories processed
@@ -88,48 +80,39 @@ class TestAsyncAddToVectorStoreErrors:
def mock_async_memory(self, mocker): def mock_async_memory(self, mocker):
"""Fixture for AsyncMemory with mocker-based mocks""" """Fixture for AsyncMemory with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker) mock_llm, _ = _setup_mocks(mocker)
memory = AsyncMemory() memory = AsyncMemory()
memory.config = mocker.MagicMock() memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1" memory.api_version = "v1.1"
return memory return memory
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker): async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store""" """Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock()) mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.return_value = "" mock_async_memory.llm.generate_response.return_value = ""
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store( result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}], messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
metadata={},
filters={},
infer=True
) )
assert result == [] assert result == []
assert "Error in new_retrieved_facts" in caplog.text assert "Error in new_retrieved_facts" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker): async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store""" """Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock()) mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.side_effect = [ mock_async_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
'{"facts": ["test fact"]}',
""
]
with caplog.at_level(logging.ERROR): with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store( result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}], messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
metadata={},
filters={},
infer=True
) )
assert result == [] assert result == []
assert "Invalid JSON response" in caplog.text assert "Invalid JSON response" in caplog.text

View File

@@ -17,11 +17,13 @@ def mock_openai():
@pytest.fixture @pytest.fixture
def memory_instance(): def memory_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch( with (
"mem0.utils.factory.VectorStoreFactory" patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch( patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
"mem0.memory.telemetry.capture_event" patch("mem0.utils.factory.LlmFactory") as mock_llm,
), patch("mem0.memory.graph_memory.MemoryGraph"): patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock() mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock() mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock() mock_llm.create.return_value = Mock()
@@ -30,13 +32,16 @@ def memory_instance():
config.graph_store.config = {"some_config": "value"} config.graph_store.config = {"some_config": "value"}
return Memory(config) return Memory(config)
@pytest.fixture @pytest.fixture
def memory_custom_instance(): def memory_custom_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch( with (
"mem0.utils.factory.VectorStoreFactory" patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch( patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
"mem0.memory.telemetry.capture_event" patch("mem0.utils.factory.LlmFactory") as mock_llm,
), patch("mem0.memory.graph_memory.MemoryGraph"): patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock() mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock() mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock() mock_llm.create.return_value = Mock()
@@ -44,7 +49,7 @@ def memory_custom_instance():
config = MemoryConfig( config = MemoryConfig(
version="v1.1", version="v1.1",
custom_fact_extraction_prompt="custom prompt extracting memory", custom_fact_extraction_prompt="custom prompt extracting memory",
custom_update_memory_prompt="custom prompt determining memory update" custom_update_memory_prompt="custom prompt determining memory update",
) )
config.graph_store.config = {"some_config": "value"} config.graph_store.config = {"some_config": "value"}
return Memory(config) return Memory(config)
@@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph):
assert result["message"] == "Memories deleted successfully!" assert result["message"] == "Memories deleted successfully!"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"version, enable_graph, expected_result", "version, enable_graph, expected_result",
[ [
@@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100) memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else: else:
memory_instance.graph.get_all.assert_not_called() memory_instance.graph.get_all.assert_not_called()
def test_custom_prompts(memory_custom_instance): def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}] messages = [{"role": "user", "content": "Test message"}]
memory_custom_instance.llm.generate_response = Mock() memory_custom_instance.llm.generate_response = Mock()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages: with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages: with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
) as mock_get_update_memory_messages:
memory_custom_instance.add(messages=messages, user_id="test_user") memory_custom_instance.add(messages=messages, user_id="test_user")
## custom prompt ## custom prompt
## ##
mock_parse_messages.assert_called_once_with(messages) mock_parse_messages.assert_called_once_with(messages)
memory_custom_instance.llm.generate_response.assert_any_call( memory_custom_instance.llm.generate_response.assert_any_call(
messages=[ messages=[
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt}, {"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
@@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance):
], ],
response_format={"type": "json_object"}, response_format={"type": "json_object"},
) )
## custom update memory prompt ## custom update memory prompt
## ##
mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt) mock_get_update_memory_messages.assert_called_once_with(
[], [], memory_custom_instance.config.custom_update_memory_prompt
)
memory_custom_instance.llm.generate_response.assert_any_call( memory_custom_instance.llm.generate_response.assert_any_call(
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}], messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
response_format={"type": "json_object"}, response_format={"type": "json_object"},
) )
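The fixture rewrites in this file gather the stacked patch() context managers into one parenthesized with statement, a form Python 3.10 and later parse natively. A minimal runnable sketch of the pattern, using stand-in patch targets rather than the mem0 factories:

from unittest.mock import patch

# Requires Python >= 3.10: several context managers grouped in parentheses,
# one per line, all entered and exited together.
with (
    patch("os.path.exists", return_value=True) as mock_exists,
    patch("os.path.getsize", return_value=0) as mock_getsize,
):
    import os

    assert os.path.exists("/no/such/file") is True
    assert os.path.getsize("/no/such/file") == 0
    mock_exists.assert_called_once_with("/no/such/file")
    mock_getsize.assert_called_once_with("/no/such/file")

Before 3.10 the same stack had to be chained on one long line, continued inside a call's parentheses, or built with contextlib.ExitStack; the parenthesized form keeps one patch per line without continuation tricks.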

View File

@@ -97,4 +97,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm
call_args = mock_litellm.completion.call_args[1] call_args = mock_litellm.completion.call_args[1]
assert call_args["messages"][0]["role"] == "system" assert call_args["messages"][0]["role"] == "system"
assert call_args["messages"][0]["content"] == "You are a helpful assistant." assert call_args["messages"][0]["content"] == "You are a helpful assistant."

View File

@@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. # Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture @pytest.fixture
def mock_clients(): def mock_clients():
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ with (
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \ patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient,
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential: patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient,
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential,
):
# Create mocked instances for search and index clients. # Create mocked instances for search and index clients.
mock_search_client = MockSearchClient.return_value mock_search_client = MockSearchClient.return_value
mock_index_client = MockIndexClient.return_value mock_index_client = MockIndexClient.return_value
# Mock the client._client._config.user_agent_policy.add_user_agent # Mock the client._client._config.user_agent_policy.add_user_agent
mock_search_client._client = MagicMock() mock_search_client._client = MagicMock()
mock_search_client._client._config.user_agent_policy.add_user_agent = Mock() mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
@@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients):
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=3, embedding_model_dims=3,
compression_type="binary", # testing binary quantization option compression_type="binary", # testing binary quantization option
use_float16=True use_float16=True,
) )
# Return instance and clients for verification. # Return instance and clients for verification.
return instance, mock_search_client, mock_index_client return instance, mock_search_client, mock_index_client
@@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients):
# --- Tests for AzureAISearchConfig --- # --- Tests for AzureAISearchConfig ---
def test_config_validation_valid(): def test_config_validation_valid():
"""Test valid configurations are accepted.""" """Test valid configurations are accepted."""
# Test minimal configuration # Test minimal configuration
config = AzureAISearchConfig( config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768
)
assert config.collection_name == "mem0" # Default value assert config.collection_name == "mem0" # Default value
assert config.service_name == "test-service" assert config.service_name == "test-service"
assert config.api_key == "test-api-key" assert config.api_key == "test-api-key"
assert config.embedding_model_dims == 768 assert config.embedding_model_dims == 768
assert config.compression_type is None assert config.compression_type is None
assert config.use_float16 is False assert config.use_float16 is False
# Test with all optional parameters # Test with all optional parameters
config = AzureAISearchConfig( config = AzureAISearchConfig(
collection_name="custom-index", collection_name="custom-index",
@@ -92,7 +91,7 @@ def test_config_validation_valid():
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=1536, embedding_model_dims=1536,
compression_type="scalar", compression_type="scalar",
use_float16=True use_float16=True,
) )
assert config.collection_name == "custom-index" assert config.collection_name == "custom-index"
assert config.compression_type == "scalar" assert config.compression_type == "scalar"
@@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type():
service_name="test-service", service_name="test-service",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type="invalid-type" # Not a valid option compression_type="invalid-type", # Not a valid option
) )
assert "Invalid compression_type" in str(exc_info.value) assert "Invalid compression_type" in str(exc_info.value)
@@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression():
service_name="test-service", service_name="test-service",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
use_compression=True # Deprecated parameter use_compression=True, # Deprecated parameter
) )
# Fix: Use a partial string match instead of exact match # Fix: Use a partial string match instead of exact match
assert "use_compression" in str(exc_info.value) assert "use_compression" in str(exc_info.value)
@@ -132,7 +131,7 @@ def test_config_validation_extra_fields():
service_name="test-service", service_name="test-service",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
unknown_parameter="value" # Extra field unknown_parameter="value", # Extra field
) )
assert "Extra fields not allowed" in str(exc_info.value) assert "Extra fields not allowed" in str(exc_info.value)
assert "unknown_parameter" in str(exc_info.value) assert "unknown_parameter" in str(exc_info.value)
@@ -140,30 +139,28 @@ def test_config_validation_extra_fields():
# --- Tests for AzureAISearch initialization --- # --- Tests for AzureAISearch initialization ---
def test_initialization(mock_clients): def test_initialization(mock_clients):
"""Test AzureAISearch initialization with different parameters.""" """Test AzureAISearch initialization with different parameters."""
mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients
# Test with minimal parameters # Test with minimal parameters
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768
collection_name="test-index",
api_key="test-api-key",
embedding_model_dims=768
) )
# Verify initialization parameters # Verify initialization parameters
assert instance.index_name == "test-index" assert instance.index_name == "test-index"
assert instance.collection_name == "test-index" assert instance.collection_name == "test-index"
assert instance.embedding_model_dims == 768 assert instance.embedding_model_dims == 768
assert instance.compression_type == "none" # Default when None is passed assert instance.compression_type == "none" # Default when None is passed
assert instance.use_float16 is False assert instance.use_float16 is False
# Verify client creation # Verify client creation
mock_azure_key_credential.assert_called_with("test-api-key") mock_azure_key_credential.assert_called_with("test-api-key")
assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0] assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0] assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]
# Verify index creation was called # Verify index creation was called
mock_index_client.create_or_update_index.assert_called_once() mock_index_client.create_or_update_index.assert_called_once()
@@ -171,75 +168,75 @@ def test_initialization(mock_clients):
def test_initialization_with_compression_types(mock_clients): def test_initialization_with_compression_types(mock_clients):
"""Test initialization with different compression types.""" """Test initialization with different compression types."""
mock_search_client, mock_index_client, _ = mock_clients mock_search_client, mock_index_client, _ = mock_clients
# Test with scalar compression # Test with scalar compression
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="scalar-index", collection_name="scalar-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type="scalar" compression_type="scalar",
) )
assert instance.compression_type == "scalar" assert instance.compression_type == "scalar"
# Capture the index creation call # Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1] args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0] index = args[0]
# Verify scalar compression was configured # Verify scalar compression was configured
assert hasattr(index.vector_search, 'compressions') assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0 assert len(index.vector_search.compressions) > 0
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with binary compression # Test with binary compression
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="binary-index", collection_name="binary-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type="binary" compression_type="binary",
) )
assert instance.compression_type == "binary" assert instance.compression_type == "binary"
# Capture the index creation call # Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1] args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0] index = args[0]
# Verify binary compression was configured # Verify binary compression was configured
assert hasattr(index.vector_search, 'compressions') assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0 assert len(index.vector_search.compressions) > 0
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with no compression # Test with no compression
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="no-compression-index", collection_name="no-compression-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type=None compression_type=None,
) )
assert instance.compression_type == "none" assert instance.compression_type == "none"
# Capture the index creation call # Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1] args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0] index = args[0]
# Verify no compression was configured # Verify no compression was configured
assert hasattr(index.vector_search, 'compressions') assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) == 0 assert len(index.vector_search.compressions) == 0
def test_initialization_with_float_precision(mock_clients): def test_initialization_with_float_precision(mock_clients):
"""Test initialization with different float precision settings.""" """Test initialization with different float precision settings."""
mock_search_client, mock_index_client, _ = mock_clients mock_search_client, mock_index_client, _ = mock_clients
# Test with half precision (float16) # Test with half precision (float16)
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="float16-index", collection_name="float16-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
use_float16=True use_float16=True,
) )
assert instance.use_float16 is True assert instance.use_float16 is True
# Capture the index creation call # Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1] args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0] index = args[0]
@@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients):
vector_field = next((f for f in index.fields if f.name == "vector"), None) vector_field = next((f for f in index.fields if f.name == "vector"), None)
assert vector_field is not None assert vector_field is not None
assert "Edm.Half" in vector_field.type assert "Edm.Half" in vector_field.type
# Test with full precision (float32) # Test with full precision (float32)
instance = AzureAISearch( instance = AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="float32-index", collection_name="float32-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
use_float16=False use_float16=False,
) )
assert instance.use_float16 is False assert instance.use_float16 is False
# Capture the index creation call # Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1] args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0] index = args[0]
@@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients):
# --- Tests for create_col method --- # --- Tests for create_col method ---
def test_create_col(azure_ai_search_instance): def test_create_col(azure_ai_search_instance):
"""Test the create_col method creates an index with the correct configuration.""" """Test the create_col method creates an index with the correct configuration."""
instance, _, mock_index_client = azure_ai_search_instance instance, _, mock_index_client = azure_ai_search_instance
# create_col is called during initialization, so we check the call that was already made # create_col is called during initialization, so we check the call that was already made
mock_index_client.create_or_update_index.assert_called_once() mock_index_client.create_or_update_index.assert_called_once()
# Verify the index configuration # Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args args, _ = mock_index_client.create_or_update_index.call_args
index = args[0] index = args[0]
# Check basic properties # Check basic properties
assert index.name == "test-index" assert index.name == "test-index"
assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload
# Check that required fields are present # Check that required fields are present
field_names = [f.name for f in index.fields] field_names = [f.name for f in index.fields]
assert "id" in field_names assert "id" in field_names
@@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance):
assert "user_id" in field_names assert "user_id" in field_names
assert "run_id" in field_names assert "run_id" in field_names
assert "agent_id" in field_names assert "agent_id" in field_names
# Check that id is the key field # Check that id is the key field
id_field = next(f for f in index.fields if f.name == "id") id_field = next(f for f in index.fields if f.name == "id")
assert id_field.key is True assert id_field.key is True
# Check vector search configuration # Check vector search configuration
assert index.vector_search is not None assert index.vector_search is not None
assert len(index.vector_search.profiles) == 1 assert len(index.vector_search.profiles) == 1
assert index.vector_search.profiles[0].name == "my-vector-config" assert index.vector_search.profiles[0].name == "my-vector-config"
assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config" assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"
# Check algorithms # Check algorithms
assert len(index.vector_search.algorithms) == 1 assert len(index.vector_search.algorithms) == 1
assert index.vector_search.algorithms[0].name == "my-algorithms-config" assert index.vector_search.algorithms[0].name == "my-algorithms-config"
assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0])) assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))
# With binary compression and float16, we should have compression configuration # With binary compression and float16, we should have compression configuration
assert len(index.vector_search.compressions) == 1 assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression" assert index.vector_search.compressions[0].compression_name == "myCompression"
@@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance):
def test_create_col_scalar_compression(mock_clients): def test_create_col_scalar_compression(mock_clients):
"""Test creating a collection with scalar compression.""" """Test creating a collection with scalar compression."""
mock_search_client, mock_index_client, _ = mock_clients mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch( AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="scalar-index", collection_name="scalar-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type="scalar" compression_type="scalar",
) )
# Verify the index configuration # Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args args, _ = mock_index_client.create_or_update_index.call_args
index = args[0] index = args[0]
# Check compression configuration # Check compression configuration
assert len(index.vector_search.compressions) == 1 assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression" assert index.vector_search.compressions[0].compression_name == "myCompression"
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Check profile references compression # Check profile references compression
assert index.vector_search.profiles[0].compression_name == "myCompression" assert index.vector_search.profiles[0].compression_name == "myCompression"
@@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients):
def test_create_col_no_compression(mock_clients): def test_create_col_no_compression(mock_clients):
"""Test creating a collection with no compression.""" """Test creating a collection with no compression."""
mock_search_client, mock_index_client, _ = mock_clients mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch( AzureAISearch(
service_name="test-service", service_name="test-service",
collection_name="no-compression-index", collection_name="no-compression-index",
api_key="test-api-key", api_key="test-api-key",
embedding_model_dims=768, embedding_model_dims=768,
compression_type=None compression_type=None,
) )
# Verify the index configuration # Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args args, _ = mock_index_client.create_or_update_index.call_args
index = args[0] index = args[0]
# Check compression configuration - should be empty # Check compression configuration - should be empty
assert len(index.vector_search.compressions) == 0 assert len(index.vector_search.compressions) == 0
# Check profile doesn't reference compression # Check profile doesn't reference compression
assert index.vector_search.profiles[0].compression_name is None assert index.vector_search.profiles[0].compression_name is None
# --- Tests for insert method --- # --- Tests for insert method ---
def test_insert_single(azure_ai_search_instance): def test_insert_single(azure_ai_search_instance):
"""Test inserting a single vector.""" """Test inserting a single vector."""
instance, mock_search_client, _ = azure_ai_search_instance instance, mock_search_client, _ = azure_ai_search_instance
@@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance):
ids = ["doc1"] ids = ["doc1"]
# Fix: Include status_code: 201 in mock response # Fix: Include status_code: 201 in mock response
mock_search_client.upload_documents.return_value = [ mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}]
{"status": True, "id": "doc1", "status_code": 201}
]
instance.insert(vectors, payloads, ids) instance.insert(vectors, payloads, ids)
@@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance):
def test_insert_multiple(azure_ai_search_instance): def test_insert_multiple(azure_ai_search_instance):
"""Test inserting multiple vectors in one call.""" """Test inserting multiple vectors in one call."""
instance, mock_search_client, _ = azure_ai_search_instance instance, mock_search_client, _ = azure_ai_search_instance
# Create multiple vectors # Create multiple vectors
num_docs = 3 num_docs = 3
vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)] vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)] payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
ids = [f"doc{i}" for i in range(num_docs)] ids = [f"doc{i}" for i in range(num_docs)]
# Configure mock to return success for all documents (fix: add status_code 201) # Configure mock to return success for all documents (fix: add status_code 201)
mock_search_client.upload_documents.return_value = [ mock_search_client.upload_documents.return_value = [
{"status": True, "id": id_val, "status_code": 201} for id_val in ids {"status": True, "id": id_val, "status_code": 201} for id_val in ids
] ]
# Insert the documents # Insert the documents
instance.insert(vectors, payloads, ids) instance.insert(vectors, payloads, ids)
# Verify upload_documents was called with correct documents # Verify upload_documents was called with correct documents
mock_search_client.upload_documents.assert_called_once() mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args args, _ = mock_search_client.upload_documents.call_args
documents = args[0] documents = args[0]
# Verify all documents were included # Verify all documents were included
assert len(documents) == num_docs assert len(documents) == num_docs
# Check first document # Check first document
assert documents[0]["id"] == "doc0" assert documents[0]["id"] == "doc0"
assert documents[0]["vector"] == [0.0, 0.1, 0.2] assert documents[0]["vector"] == [0.0, 0.1, 0.2]
assert documents[0]["payload"] == json.dumps(payloads[0]) assert documents[0]["payload"] == json.dumps(payloads[0])
assert documents[0]["user_id"] == "user0" assert documents[0]["user_id"] == "user0"
# Check last document # Check last document
assert documents[2]["id"] == "doc2" assert documents[2]["id"] == "doc2"
assert documents[2]["vector"] == [0.2, 0.3, 0.4] assert documents[2]["vector"] == [0.2, 0.3, 0.4]
@@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return an error for one document # Configure mock to return an error for one document
mock_search_client.upload_documents.return_value = [ mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}]
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
]
vectors = [[0.1, 0.2, 0.3]] vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}] payloads = [{"user_id": "user1"}]
@@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance):
# Configure mock to return mixed success/failure for multiple documents # Configure mock to return mixed success/failure for multiple documents
mock_search_client.upload_documents.return_value = [ mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1"}, # This should not cause failure {"status": True, "id": "doc1"}, # This should not cause failure
{"status": False, "id": "doc2", "errorMessage": "Azure error"} {"status": False, "id": "doc2", "errorMessage": "Azure error"},
] ]
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
instance.insert(vectors, payloads, ids) instance.insert(vectors, payloads, ids)
assert "Insert failed for document doc2" in str(exc_info.value) or \ assert "Insert failed for document doc2" in str(exc_info.value) or "Insert failed for document doc1" in str(
"Insert failed for document doc1" in str(exc_info.value) exc_info.value
)
def test_insert_with_missing_payload_fields(azure_ai_search_instance): def test_insert_with_missing_payload_fields(azure_ai_search_instance):
@@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance):
def test_insert_with_http_error(azure_ai_search_instance): def test_insert_with_http_error(azure_ai_search_instance):
"""Test insert when Azure client throws an HTTP error.""" """Test insert when Azure client throws an HTTP error."""
instance, mock_search_client, _ = azure_ai_search_instance instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to raise an HttpResponseError # Configure mock to raise an HttpResponseError
mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error") mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")
vectors = [[0.1, 0.2, 0.3]] vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}] payloads = [{"user_id": "user1"}]
ids = ["doc1"] ids = ["doc1"]
# Insert should propagate the HTTP error # Insert should propagate the HTTP error
with pytest.raises(HttpResponseError) as exc_info: with pytest.raises(HttpResponseError) as exc_info:
instance.insert(vectors, payloads, ids) instance.insert(vectors, payloads, ids)
assert "Azure service error" in str(exc_info.value) assert "Azure service error" in str(exc_info.value)
# --- Tests for search method --- # --- Tests for search method ---
def test_search_basic(azure_ai_search_instance): def test_search_basic(azure_ai_search_instance):
"""Test basic vector search without filters.""" """Test basic vector search without filters."""
instance, mock_search_client, _ = azure_ai_search_instance instance, mock_search_client, _ = azure_ai_search_instance
@@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance):
# Search with a vector # Search with a vector
query_text = "test query" # Add a query string query_text = "test query" # Add a query string
query_vector = [0.1, 0.2, 0.3] query_vector = [0.1, 0.2, 0.3]
results = instance.search( results = instance.search(query_text, query_vector, limit=5) # Pass the query string
query_text, query_vector, limit=5
) # Pass the query string
# Verify search was called correctly # Verify search was called correctly
mock_search_client.search.assert_called_once() mock_search_client.search.assert_called_once()
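Nearly every hunk in this file is one of two mechanical changes: a call that fits within the line-length limit is collapsed onto a single line, and a call that stays multi-line gains a trailing comma after its last argument. This is consistent with a Black/Ruff-style formatter and its magic trailing comma, sketched below with an illustrative stand-in function rather than the real constructors:

def configure(**kwargs):  # illustrative stand-in for AzureAISearch / AzureAISearchConfig
    return kwargs

# Collapsed: short enough for one line, so no trailing comma is kept.
short = configure(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)

# Exploded: the trailing comma after the last argument pins the one-argument-per-line layout,
# even if a later edit would make the call short enough to collapse.
long = configure(
    service_name="test-service",
    api_key="test-api-key",
    embedding_model_dims=768,
    compression_type="scalar",
)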

View File

@@ -7,9 +7,7 @@ import dotenv
try: try:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
except ImportError: except ImportError:
raise ImportError( raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None
"Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`"
) from None
from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData
@@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
# Load environment variables before any test # Load environment variables before any test
dotenv.load_dotenv() dotenv.load_dotenv()
# Save original environment variables # Save original environment variables
cls.original_env = { cls.original_env = {
'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'), "ES_URL": os.getenv("ES_URL", "http://localhost:9200"),
'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'), "ES_USERNAME": os.getenv("ES_USERNAME", "test_user"),
'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'), "ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"),
'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id') "ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"),
} }
# Set test environment variables # Set test environment variables
os.environ['ES_URL'] = 'http://localhost' os.environ["ES_URL"] = "http://localhost"
os.environ['ES_USERNAME'] = 'test_user' os.environ["ES_USERNAME"] = "test_user"
os.environ['ES_PASSWORD'] = 'test_password' os.environ["ES_PASSWORD"] = "test_password"
def setUp(self): def setUp(self):
# Create a mock Elasticsearch client with proper attributes # Create a mock Elasticsearch client with proper attributes
self.client_mock = MagicMock(spec=Elasticsearch) self.client_mock = MagicMock(spec=Elasticsearch)
@@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase):
self.client_mock.indices.create = MagicMock() self.client_mock.indices.create = MagicMock()
self.client_mock.indices.delete = MagicMock() self.client_mock.indices.delete = MagicMock()
self.client_mock.indices.get_alias = MagicMock() self.client_mock.indices.get_alias = MagicMock()
# Start patches BEFORE creating ElasticsearchDB instance # Start patches BEFORE creating ElasticsearchDB instance
patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock) patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock)
self.mock_es = patcher.start() self.mock_es = patcher.start()
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
# Initialize ElasticsearchDB with test config and auto_create_index=False # Initialize ElasticsearchDB with test config and auto_create_index=False
self.es_db = ElasticsearchDB( self.es_db = ElasticsearchDB(
host=os.getenv('ES_URL'), host=os.getenv("ES_URL"),
port=9200, port=9200,
collection_name="test_collection", collection_name="test_collection",
embedding_model_dims=1536, embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'), user=os.getenv("ES_USERNAME"),
password=os.getenv('ES_PASSWORD'), password=os.getenv("ES_PASSWORD"),
verify_certs=False, verify_certs=False,
use_ssl=False, use_ssl=False,
auto_create_index=False # Disable auto creation for tests auto_create_index=False, # Disable auto creation for tests
) )
# Reset mock counts after initialization # Reset mock counts after initialization
self.client_mock.reset_mock() self.client_mock.reset_mock()
@@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase):
# Test when index doesn't exist # Test when index doesn't exist
self.client_mock.indices.exists.return_value = False self.client_mock.indices.exists.return_value = False
self.es_db.create_index() self.es_db.create_index()
# Verify index creation was called with correct settings # Verify index creation was called with correct settings
self.client_mock.indices.create.assert_called_once() self.client_mock.indices.create.assert_called_once()
create_args = self.client_mock.indices.create.call_args[1] create_args = self.client_mock.indices.create.call_args[1]
# Verify basic index settings # Verify basic index settings
self.assertEqual(create_args["index"], "test_collection") self.assertEqual(create_args["index"], "test_collection")
self.assertIn("mappings", create_args["body"]) self.assertIn("mappings", create_args["body"])
# Verify field mappings # Verify field mappings
mappings = create_args["body"]["mappings"]["properties"] mappings = create_args["body"]["mappings"]["properties"]
self.assertEqual(mappings["text"]["type"], "text") self.assertEqual(mappings["text"]["type"], "text")
@@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(mappings["vector"]["index"], True) self.assertEqual(mappings["vector"]["index"], True)
self.assertEqual(mappings["vector"]["similarity"], "cosine") self.assertEqual(mappings["vector"]["similarity"], "cosine")
self.assertEqual(mappings["metadata"]["type"], "object") self.assertEqual(mappings["metadata"]["type"], "object")
# Reset mocks for next test # Reset mocks for next test
self.client_mock.reset_mock() self.client_mock.reset_mock()
# Test when index already exists # Test when index already exists
self.client_mock.indices.exists.return_value = True self.client_mock.indices.exists.return_value = True
self.es_db.create_index() self.es_db.create_index()
# Verify create was not called when index exists # Verify create was not called when index exists
self.client_mock.indices.create.assert_not_called() self.client_mock.indices.create.assert_not_called()
def test_auto_create_index(self): def test_auto_create_index(self):
# Reset mock # Reset mock
self.client_mock.reset_mock() self.client_mock.reset_mock()
# Test with auto_create_index=True # Test with auto_create_index=True
ElasticsearchDB( ElasticsearchDB(
host=os.getenv('ES_URL'), host=os.getenv("ES_URL"),
port=9200, port=9200,
collection_name="test_collection", collection_name="test_collection",
embedding_model_dims=1536, embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'), user=os.getenv("ES_USERNAME"),
password=os.getenv('ES_PASSWORD'), password=os.getenv("ES_PASSWORD"),
verify_certs=False, verify_certs=False,
use_ssl=False, use_ssl=False,
auto_create_index=True auto_create_index=True,
) )
# Verify create_index was called during initialization # Verify create_index was called during initialization
self.client_mock.indices.exists.assert_called_once() self.client_mock.indices.exists.assert_called_once()
# Reset mock # Reset mock
self.client_mock.reset_mock() self.client_mock.reset_mock()
# Test with auto_create_index=False # Test with auto_create_index=False
ElasticsearchDB( ElasticsearchDB(
host=os.getenv('ES_URL'), host=os.getenv("ES_URL"),
port=9200, port=9200,
collection_name="test_collection", collection_name="test_collection",
embedding_model_dims=1536, embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'), user=os.getenv("ES_USERNAME"),
password=os.getenv('ES_PASSWORD'), password=os.getenv("ES_PASSWORD"),
verify_certs=False, verify_certs=False,
use_ssl=False, use_ssl=False,
auto_create_index=False auto_create_index=False,
) )
# Verify create_index was not called during initialization # Verify create_index was not called during initialization
self.client_mock.indices.exists.assert_not_called() self.client_mock.indices.exists.assert_not_called()
@@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536] vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}] payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"] ids = ["id1", "id2"]
# Mock bulk operation # Mock bulk operation
with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk: with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk:
mock_bulk.return_value = (2, []) # Simulate successful bulk insert mock_bulk.return_value = (2, []) # Simulate successful bulk insert
# Perform insert # Perform insert
results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids) results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify bulk was called # Verify bulk was called
mock_bulk.assert_called_once() mock_bulk.assert_called_once()
# Verify bulk actions format # Verify bulk actions format
actions = mock_bulk.call_args[0][1] actions = mock_bulk.call_args[0][1]
self.assertEqual(len(actions), 2) self.assertEqual(len(actions), 2)
@@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(actions[0]["_id"], "id1") self.assertEqual(actions[0]["_id"], "id1")
self.assertEqual(actions[0]["_source"]["vector"], vectors[0]) self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
self.assertEqual(actions[0]["_source"]["metadata"], payloads[0]) self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])
# Verify returned objects # Verify returned objects
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertIsInstance(results[0], OutputData) self.assertIsInstance(results[0], OutputData)
@@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = { mock_response = {
"hits": { "hits": {
"hits": [ "hits": [
{ {"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}
"_id": "id1",
"_score": 0.8,
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
}
}
] ]
} }
} }
@@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase):
# Verify search parameters # Verify search parameters
self.assertEqual(search_args["index"], "test_collection") self.assertEqual(search_args["index"], "test_collection")
body = search_args["body"] body = search_args["body"]
# Verify KNN query structure # Verify KNN query structure
self.assertIn("knn", body) self.assertIn("knn", body)
self.assertEqual(body["knn"]["field"], "vector") self.assertEqual(body["knn"]["field"], "vector")
@@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase):
self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters) self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)
# Verify custom search query was used # Verify custom search query was used
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"}) self.client_mock.search.assert_called_once_with(
index=self.es_db.collection_name, body={"custom_key": "custom_value"}
)
def test_get(self): def test_get(self):
# Mock get response with correct structure # Mock get response with correct structure
mock_response = { mock_response = {
"_id": "id1", "_id": "id1",
"_source": { "_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"},
"vector": [0.1] * 1536,
"metadata": {"key": "value"},
"text": "sample text"
}
} }
self.client_mock.get.return_value = mock_response self.client_mock.get.return_value = mock_response
# Perform get # Perform get
result = self.es_db.get(vector_id="id1") result = self.es_db.get(vector_id="id1")
# Verify get call # Verify get call
self.client_mock.get.assert_called_once_with( self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")
index="test_collection",
id="id1"
)
# Verify result # Verify result
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.assertEqual(result.id, "id1") self.assertEqual(result.id, "id1")
@@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase):
def test_get_not_found(self): def test_get_not_found(self):
# Mock get raising exception # Mock get raising exception
self.client_mock.get.side_effect = Exception("Not found") self.client_mock.get.side_effect = Exception("Not found")
# Verify get returns None when document not found # Verify get returns None when document not found
result = self.es_db.get(vector_id="nonexistent") result = self.es_db.get(vector_id="nonexistent")
self.assertIsNone(result) self.assertIsNone(result)
@@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = { mock_response = {
"hits": { "hits": {
"hits": [ "hits": [
{ {"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0},
"_id": "id1", {"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8},
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
},
"_score": 1.0
},
{
"_id": "id2",
"_source": {
"vector": [0.2] * 1536,
"metadata": {"key2": "value2"}
},
"_score": 0.8
}
] ]
} }
} }
self.client_mock.search.return_value = mock_response self.client_mock.search.return_value = mock_response
# Perform list operation # Perform list operation
results = self.es_db.list(limit=10) results = self.es_db.list(limit=10)
# Verify search call # Verify search call
self.client_mock.search.assert_called_once() self.client_mock.search.assert_called_once()
# Verify results # Verify results
self.assertEqual(len(results), 1) # Outer list self.assertEqual(len(results), 1) # Outer list
self.assertEqual(len(results[0]), 2) # Inner list self.assertEqual(len(results[0]), 2) # Inner list
@@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase):
def test_delete(self): def test_delete(self):
# Perform delete # Perform delete
self.es_db.delete(vector_id="id1") self.es_db.delete(vector_id="id1")
# Verify delete call # Verify delete call
self.client_mock.delete.assert_called_once_with( self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1")
index="test_collection",
id="id1"
)
def test_list_cols(self): def test_list_cols(self):
# Mock indices response # Mock indices response
mock_indices = {"index1": {}, "index2": {}} mock_indices = {"index1": {}, "index2": {}}
self.client_mock.indices.get_alias.return_value = mock_indices self.client_mock.indices.get_alias.return_value = mock_indices
# Get collections # Get collections
result = self.es_db.list_cols() result = self.es_db.list_cols()
# Verify result # Verify result
self.assertEqual(result, ["index1", "index2"]) self.assertEqual(result, ["index1", "index2"])
def test_delete_col(self): def test_delete_col(self):
# Delete collection # Delete collection
self.es_db.delete_col() self.es_db.delete_col()
# Verify delete call # Verify delete call
self.client_mock.indices.delete.assert_called_once_with( self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
index="test_collection"
)
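The setUp in this file starts the Elasticsearch patch by hand and registers its teardown with addCleanup, so the mocked client is in place before ElasticsearchDB is constructed and is always stopped afterwards, even when a test fails. A minimal standalone sketch of that pattern with an illustrative patch target:

import unittest
from unittest.mock import patch


class PatcherCleanupExample(unittest.TestCase):
    def setUp(self):
        # Start the patch manually so it stays active for the whole test method...
        patcher = patch("os.getcwd", return_value="/tmp/fake")
        self.mock_getcwd = patcher.start()
        # ...and guarantee it is undone even if the test body raises.
        self.addCleanup(patcher.stop)

    def test_patch_is_active(self):
        import os

        self.assertEqual(os.getcwd(), "/tmp/fake")


if __name__ == "__main__":
    unittest.main()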

View File

@@ -21,9 +21,9 @@ def mock_faiss_index():
def faiss_instance(mock_faiss_index): def faiss_instance(mock_faiss_index):
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Mock the faiss index creation # Mock the faiss index creation
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index): with patch("faiss.IndexFlatL2", return_value=mock_faiss_index):
# Mock the faiss.write_index function # Mock the faiss.write_index function
with patch('faiss.write_index'): with patch("faiss.write_index"):
# Create a FAISS instance with a temporary directory # Create a FAISS instance with a temporary directory
faiss_store = FAISS( faiss_store = FAISS(
collection_name="test_collection", collection_name="test_collection",
@@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index):
def test_create_col(faiss_instance, mock_faiss_index): def test_create_col(faiss_instance, mock_faiss_index):
# Test creating a collection with euclidean distance # Test creating a collection with euclidean distance
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2: with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as mock_index_flat_l2:
with patch('faiss.write_index'): with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection") faiss_instance.create_col(name="new_collection")
mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims) mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)
# Test creating a collection with inner product distance # Test creating a collection with inner product distance
with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip: with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip:
with patch('faiss.write_index'): with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection", distance="inner_product") faiss_instance.create_col(name="new_collection", distance="inner_product")
mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims) mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
@@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index):
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"name": "vector1"}, {"name": "vector2"}] payloads = [{"name": "vector1"}, {"name": "vector2"}]
ids = ["id1", "id2"] ids = ["id1", "id2"]
    # Mock the numpy array conversion
-    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
+    with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
        # Mock index.add
        mock_faiss_index.add.return_value = None
        # Call insert
        faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
        # Verify numpy.array was called
        mock_np_array.assert_called_once_with(vectors, dtype=np.float32)
        # Verify index.add was called
        mock_faiss_index.add.assert_called_once()
        # Verify docstore and index_to_id were updated
        assert faiss_instance.docstore["id1"] == {"name": "vector1"}
        assert faiss_instance.docstore["id2"] == {"name": "vector2"}
@@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index):
def test_search(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}
    # First, create the mock for the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)
    # Then patch numpy.array only for the query vector conversion
-    with patch('numpy.array') as mock_np_array:
+    with patch("numpy.array") as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
        # Then patch _parse_output to return the expected results
        expected_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
-            OutputData(id="id2", score=0.8, payload={"name": "vector2"})
+            OutputData(id="id2", score=0.8, payload={"name": "vector2"}),
        ]
-        with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
+        with patch.object(faiss_instance, "_parse_output", return_value=expected_results):
            # Call search
            results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)
            # Verify numpy.array was called (but we don't check exact call arguments since it's complex)
            assert mock_np_array.called
            # Verify index.search was called
            mock_faiss_index.search.assert_called_once()
            # Verify results
            assert len(results) == 2
            assert results[0].id == "id1"
@@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index):
def test_search_with_filters(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1", "category": "A"},
-        "id2": {"name": "vector2", "category": "B"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}
    # First set up the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)
    # Patch numpy.array for query vector conversion
-    with patch('numpy.array') as mock_np_array:
+    with patch("numpy.array") as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
        # Directly mock the _parse_output method to return our expected values
        # We're simulating that _parse_output filters to just the first result
        all_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
-            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"})
+            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}),
        ]
        # Replace the _apply_filters method to handle our test case
-        with patch.object(faiss_instance, '_parse_output', return_value=all_results):
-            with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"):
+        with patch.object(faiss_instance, "_parse_output", return_value=all_results):
+            with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"):
                # Call search with filters
                results = faiss_instance.search(
-                    query="test query",
-                    vectors=query_vector,
-                    limit=2,
-                    filters={"category": "A"}
+                    query="test query", vectors=query_vector, limit=2, filters={"category": "A"}
                )
                # Verify numpy.array was called
                assert mock_np_array.called
                # Verify index.search was called
                mock_faiss_index.search.assert_called_once()
                # Verify filtered results - since we've mocked everything,
                # we should get just the result we want
                assert len(results) == 1
@@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index):
def test_delete(faiss_instance):
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}
    # Call delete
    faiss_instance.delete(vector_id="id1")
    # Verify the vector was removed from docstore and index_to_id
    assert "id1" not in faiss_instance.docstore
    assert 0 not in faiss_instance.index_to_id
@@ -194,23 +182,20 @@ def test_delete(faiss_instance):
def test_update(faiss_instance, mock_faiss_index):
    # Setup the docstore and index_to_id mapping
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}
    # Test updating payload only
    faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
    assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}
    # Test updating vector
    # This requires mocking the delete and insert methods
-    with patch.object(faiss_instance, 'delete') as mock_delete:
-        with patch.object(faiss_instance, 'insert') as mock_insert:
+    with patch.object(faiss_instance, "delete") as mock_delete:
+        with patch.object(faiss_instance, "insert") as mock_insert:
            new_vector = [0.7, 0.8, 0.9]
            faiss_instance.update(vector_id="id2", vector=new_vector)
            # Verify delete and insert were called
            # Match the actual call signature (positional arg instead of keyword)
            mock_delete.assert_called_once_with("id2")
@@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index):
def test_get(faiss_instance):
    # Setup the docstore
-    faiss_instance.docstore = {
-        "id1": {"name": "vector1"},
-        "id2": {"name": "vector2"}
-    }
+    faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
    # Test getting an existing vector
    result = faiss_instance.get(vector_id="id1")
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
    assert result.score is None
    # Test getting a non-existent vector
    result = faiss_instance.get(vector_id="id3")
    assert result is None
@@ -240,18 +222,18 @@ def test_list(faiss_instance):
    faiss_instance.docstore = {
        "id1": {"name": "vector1", "category": "A"},
        "id2": {"name": "vector2", "category": "B"},
-        "id3": {"name": "vector3", "category": "A"}
+        "id3": {"name": "vector3", "category": "A"},
    }
    # Test listing all vectors
    results = faiss_instance.list()
    # Fix the expected result - the list method returns a list of lists
    assert len(results[0]) == 3
    # Test listing with a limit
    results = faiss_instance.list(limit=2)
    assert len(results[0]) == 2
    # Test listing with filters
    results = faiss_instance.list(filters={"category": "A"})
    assert len(results[0]) == 2
@@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index):
    # Mock index attributes
    mock_faiss_index.ntotal = 5
    mock_faiss_index.d = 128
    # Get collection info
    info = faiss_instance.col_info()
    # Verify the returned info
    assert info["name"] == "test_collection"
    assert info["count"] == 5
@@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index):
def test_delete_col(faiss_instance):
    # Mock the os.remove function
-    with patch('os.remove') as mock_remove:
-        with patch('os.path.exists', return_value=True):
+    with patch("os.remove") as mock_remove:
+        with patch("os.path.exists", return_value=True):
            # Call delete_col
            faiss_instance.delete_col()
            # Verify os.remove was called twice (for index and docstore files)
            assert mock_remove.call_count == 2
            # Verify the internal state was reset
            assert faiss_instance.index is None
            assert faiss_instance.docstore == {}
@@ -293,17 +275,17 @@ def test_delete_col(faiss_instance):
def test_normalize_L2(faiss_instance, mock_faiss_index):
    # Setup a FAISS instance with normalize_L2=True
    faiss_instance.normalize_L2 = True
    # Prepare test data
    vectors = [[0.1, 0.2, 0.3]]
    # Mock numpy array conversion
    # Mock numpy array conversion
-    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)):
+    with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)):
        # Mock faiss.normalize_L2
-        with patch('faiss.normalize_L2') as mock_normalize:
+        with patch("faiss.normalize_L2") as mock_normalize:
            # Call insert
            faiss_instance.insert(vectors=vectors, ids=["id1"])
            # Verify faiss.normalize_L2 was called
            mock_normalize.assert_called_once()


@@ -11,11 +11,13 @@ def mock_langchain_client():
    with patch("langchain_community.vectorstores.VectorStore") as mock_client:
        yield mock_client
@pytest.fixture
def langchain_instance(mock_langchain_client):
    mock_client = Mock(spec=VectorStore)
    return Langchain(client=mock_client, collection_name="test_collection")
def test_insert_vectors(langchain_instance):
    # Test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance):
    # Test with add_embeddings method
    langchain_instance.client.add_embeddings = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
-    langchain_instance.client.add_embeddings.assert_called_once_with(
-        embeddings=vectors,
-        metadatas=payloads,
-        ids=ids
-    )
+    langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids)
    # Test with add_texts method
    delattr(langchain_instance.client, "add_embeddings")  # Remove attribute completely
    langchain_instance.client.add_texts = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
-    langchain_instance.client.add_texts.assert_called_once_with(
-        texts=["text1", "text2"],
-        metadatas=payloads,
-        ids=ids
-    )
+    langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids)
    # Test with empty payloads
    langchain_instance.client.add_texts.reset_mock()
    langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
-    langchain_instance.client.add_texts.assert_called_once_with(
-        texts=["", ""],
-        metadatas=None,
-        ids=ids
-    )
+    langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids)
def test_search_vectors(langchain_instance):
    # Mock search results
-    mock_docs = [
-        Mock(metadata={"name": "vector1"}, id="id1"),
-        Mock(metadata={"name": "vector2"}, id="id2")
-    ]
+    mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")]
    langchain_instance.client.similarity_search_by_vector.return_value = mock_docs
    # Test search without filters
    vectors = [[0.1, 0.2, 0.3]]
    results = langchain_instance.search(query="", vectors=vectors, limit=2)
-    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
-        embedding=vectors,
-        k=2
-    )
+    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2)
    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
@@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance):
    # Test search with filters
    filters = {"name": "vector1"}
    langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
-    langchain_instance.client.similarity_search_by_vector.assert_called_with(
-        embedding=vectors,
-        k=2,
-        filter=filters
-    )
+    langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters)
def test_get_vector(langchain_instance):
    # Mock get result
@@ -90,7 +72,7 @@ def test_get_vector(langchain_instance):
    # Test get existing vector
    result = langchain_instance.get("id1")
    langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])
    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}


@@ -8,9 +8,7 @@ import pytest
try:
    from opensearchpy import AWSV4SignerAuth, OpenSearch
except ImportError:
-    raise ImportError(
-        "OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
-    ) from None
+    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
from mem0.vector_stores.opensearch import OpenSearchDB
@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
    def setUpClass(cls):
        dotenv.load_dotenv()
        cls.original_env = {
-            'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
-            'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
-            'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
+            "OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
+            "OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
+            "OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
        }
-        os.environ['OS_URL'] = 'http://localhost'
-        os.environ['OS_USERNAME'] = 'test_user'
-        os.environ['OS_PASSWORD'] = 'test_password'
+        os.environ["OS_URL"] = "http://localhost"
+        os.environ["OS_USERNAME"] = "test_user"
+        os.environ["OS_PASSWORD"] = "test_password"
    def setUp(self):
        self.client_mock = MagicMock(spec=OpenSearch)
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.delete = MagicMock()
        self.client_mock.search = MagicMock()
-        patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
+        patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
        self.mock_os = patcher.start()
        self.addCleanup(patcher.stop)
        self.os_db = OpenSearchDB(
-            host=os.getenv('OS_URL'),
+            host=os.getenv("OS_URL"),
            port=9200,
            collection_name="test_collection",
            embedding_model_dims=1536,
-            user=os.getenv('OS_USERNAME'),
-            password=os.getenv('OS_PASSWORD'),
+            user=os.getenv("OS_USERNAME"),
+            password=os.getenv("OS_PASSWORD"),
            verify_certs=False,
-            use_ssl=False
+            use_ssl=False,
        )
        self.client_mock.reset_mock()
@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
        vectors = [[0.1] * 1536, [0.2] * 1536]
        payloads = [{"key1": "value1"}, {"key2": "value2"}]
        ids = ["id1", "id2"]
        # Mock the index method
        self.client_mock.index = MagicMock()
        results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)
        # Verify index was called twice (once for each vector)
        self.assertEqual(self.client_mock.index.call_count, 2)
        # Check first call
        first_call = self.client_mock.index.call_args_list[0]
        self.assertEqual(first_call[1]["index"], "test_collection")
        self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
        self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
        self.assertEqual(first_call[1]["body"]["id"], ids[0])
        # Check second call
        second_call = self.client_mock.index.call_args_list[1]
        self.assertEqual(second_call[1]["index"], "test_collection")
        self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
        self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
        self.assertEqual(second_call[1]["body"]["id"], ids[1])
        # Check results
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0].id, "id1")
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
        self.client_mock.search.return_value = {"hits": {"hits": []}}
        result = self.os_db.get("nonexistent")
        self.assertIsNone(result)
    def test_update(self):
        vector = [0.3] * 1536
        payload = {"key3": "value3"}
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
        self.assertEqual(result, ["test_collection"])
    def test_search(self):
-        mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}}
+        mock_response = {
+            "hits": {
+                "hits": [
+                    {
+                        "_id": "id1",
+                        "_score": 0.8,
+                        "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
+                    }
+                ]
+            }
+        }
        self.client_mock.search.return_value = mock_response
        vectors = [[0.1] * 1536]
        results = self.os_db.search(query="", vectors=vectors, limit=5)
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
        self.os_db.delete_col()
        self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
    def test_init_with_http_auth(self):
        mock_credentials = MagicMock()
        mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
-        with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
+        with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
            OpenSearchDB(
                host="localhost",
                port=9200,
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
                embedding_model_dims=1536,
                http_auth=mock_signer,
                verify_certs=True,
-                use_ssl=True
+                use_ssl=True,
            )
            # Verify OpenSearch was initialized with correct params
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
                use_ssl=True,
                verify_certs=True,
                connection_class=unittest.mock.ANY,
-                pool_maxsize=20
+                pool_maxsize=20,
            )


@@ -12,6 +12,7 @@ def mock_pinecone_client():
    client.list_indexes.return_value.names.return_value = []
    return client
@pytest.fixture
def pinecone_db(mock_pinecone_client):
    return PineconeDB(
@@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client):
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
-        extra_params=None
+        extra_params=None,
    )
def test_create_col_existing_index(mock_pinecone_client):
    # Set up the mock before creating the PineconeDB object
    mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]
    pinecone_db = PineconeDB(
        collection_name="test_index",
        embedding_model_dims=128,
@@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client):
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
-        extra_params=None
+        extra_params=None,
    )
    # Reset the mock to verify it wasn't called during the test
    mock_pinecone_client.create_index.reset_mock()
    pinecone_db.create_col(128, "cosine")
    mock_pinecone_client.create_index.assert_not_called()
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
    mock_pinecone_client.list_indexes.return_value.names.return_value = []
    pinecone_db.create_col(128, "cosine")
    mock_pinecone_client.create_index.assert_called()
def test_insert_vectors(pinecone_db):
    vectors = [[0.1] * 128, [0.2] * 128]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
@@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db):
    pinecone_db.insert(vectors, payloads, ids)
    pinecone_db.index.upsert.assert_called()
def test_search_vectors(pinecone_db):
    pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
-    results = pinecone_db.search("test query",[0.1] * 128, limit=1)
+    results = pinecone_db.search("test query", [0.1] * 128, limit=1)
    assert len(results) == 1
    assert results[0].id == "id1"
    assert results[0].score == 0.9
def test_update_vector(pinecone_db):
    pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
    pinecone_db.index.upsert.assert_called()
def test_get_vector_found(pinecone_db):
    # Looking at the _parse_output method, it expects a Vector object
    # or a list of dictionaries, not a dictionary with an 'id' field
    # Create a mock Vector object
    from pinecone.data.dataclasses.vector import Vector
-    mock_vector = Vector(
-        id="id1",
-        values=[0.1] * 128,
-        metadata={"name": "vector1"}
-    )
+    mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
    # Mock the fetch method to return the mock response object
    mock_response = MagicMock()
    mock_response.vectors = {"id1": mock_vector}
    pinecone_db.index.fetch.return_value = mock_response
    result = pinecone_db.get("id1")
    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
def test_delete_vector(pinecone_db):
    pinecone_db.delete("id1")
    pinecone_db.index.delete.assert_called_with(ids=["id1"])
def test_get_vector_not_found(pinecone_db):
    pinecone_db.index.fetch.return_value.vectors = {}
    result = pinecone_db.get("id1")
    assert result is None
def test_list_cols(pinecone_db):
    pinecone_db.list_cols()
    pinecone_db.client.list_indexes.assert_called()
def test_delete_col(pinecone_db):
    pinecone_db.delete_col()
    pinecone_db.client.delete_index.assert_called_with("test_index")
def test_col_info(pinecone_db):
    pinecone_db.col_info()
    pinecone_db.client.describe_index.assert_called_with("test_index")


@@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection):
        index_method=IndexMethod.HNSW,
        index_measure=IndexMeasure.COSINE,
    )
    # Manually set the collection attribute since we're mocking the initialization
    instance.collection = mock_collection
    return instance
@@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection):
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
    supabase_instance.create_col(1536)
-    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
-        name="test_collection",
-        dimension=1536
-    )
-    mock_collection.create_index.assert_called_with(
-        method="hnsw",
-        measure="cosine_distance"
-    )
+    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
+    mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")
def test_insert_vectors(supabase_instance, mock_collection):
@@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection):
    supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
-    expected_records = [
-        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
-        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
-    ]
+    expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
    mock_collection.upsert.assert_called_once_with(expected_records)
def test_search_vectors(supabase_instance, mock_collection):
-    mock_results = [
-        ("id1", 0.9, {"name": "vector1"}),
-        ("id2", 0.8, {"name": "vector2"})
-    ]
+    mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})]
    mock_collection.query.return_value = mock_results
    vectors = [[0.1, 0.2, 0.3]]
@@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection):
    results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)
    mock_collection.query.assert_called_once_with(
-        data=vectors,
-        limit=2,
-        filters={"category": {"$eq": "test"}},
-        include_metadata=True,
-        include_value=True
+        data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True
    )
    assert len(results) == 2
@@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection):
def test_list_vectors(supabase_instance, mock_collection):
    mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
-    mock_fetch_results = [
-        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
-        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
-    ]
+    mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
    mock_collection.query.return_value = mock_query_results
    mock_collection.fetch.return_value = mock_fetch_results
@@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection):
        "name": "test_collection",
        "count": 100,
        "dimension": 1536,
-        "index": {
-            "method": "hnsw",
-            "metric": "cosine_distance"
-        }
+        "index": {"method": "hnsw", "metric": "cosine_distance"},
    }
@@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance):
    # Test multiple filters
    multi_filter = {"category": "test", "type": "document"}
    assert supabase_instance._preprocess_filters(multi_filter) == {
-        "$and": [
-            {"category": {"$eq": "test"}},
-            {"type": {"$eq": "document"}}
-        ]
+        "$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}]
    }
    # Test None filters


@@ -29,9 +29,7 @@ def upstash_instance(mock_index):
@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
-    return UpstashVector(
-        client=mock_index.return_value, collection_name="ns", enable_embeddings=True
-    )
+    return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True)
def test_insert_vectors(upstash_instance, mock_index):
@@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index):
def test_search_vectors(upstash_instance, mock_index):
    mock_result = [
-        QueryResult(
-            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
-        ),
-        QueryResult(
-            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
-        ),
+        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None),
+        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None),
    ]
    upstash_instance.client.query_many.return_value = [mock_result]
@@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance):
    upstash_instance.delete(vector_id=vector_id)
-    upstash_instance.client.delete.assert_called_once_with(
-        ids=[vector_id], namespace="ns"
-    )
+    upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns")
def test_update_vector(upstash_instance):
@@ -115,18 +107,12 @@ def test_update_vector(upstash_instance):
def test_get_vector(upstash_instance):
-    mock_result = [
-        QueryResult(
-            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
-        )
-    ]
+    mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.fetch.return_value = mock_result
    result = upstash_instance.get(vector_id="id1")
-    upstash_instance.client.fetch.assert_called_once_with(
-        ids=["id1"], namespace="ns", include_metadata=True
-    )
+    upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
@@ -134,15 +120,9 @@ def test_get_vector(upstash_instance):
def test_list_vectors(upstash_instance):
    mock_result = [
-        QueryResult(
-            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
-        ),
-        QueryResult(
-            id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
-        ),
-        QueryResult(
-            id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
-        ),
+        QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
+        QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
+        QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
    ]
    handler = MagicMock()
@@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i
def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
    mock_result = [
-        QueryResult(
-            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
-        ),
-        QueryResult(
-            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
-        ),
+        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
+        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
    ]
    upstash_instance_with_embeddings.client.query.return_value = mock_result
@@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed
        ValueError,
        match="When embeddings are enabled, all payloads must contain a 'data' field",
    ):
-        upstash_instance_with_embeddings.insert(
-            vectors=vectors, payloads=payloads, ids=ids
-        )
+        upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)
def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
@@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance):
    result = upstash_instance.get(vector_id="nonexistent")
-    upstash_instance.client.fetch.assert_called_once_with(
-        ids=["nonexistent"], namespace="ns", include_metadata=True
-    )
+    upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
    assert result is None
def test_search_vectors_empty_filters(upstash_instance):
-    mock_result = [
-        QueryResult(
-            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
-        )
-    ]
+    mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)]
    upstash_instance.client.query_many.return_value = [mock_result]
    vectors = [[0.1, 0.2, 0.3]]


@@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine
@pytest.fixture
def mock_vertex_ai():
-    with patch('google.cloud.aiplatform.MatchingEngineIndex') as mock_index, \
-        patch('google.cloud.aiplatform.MatchingEngineIndexEndpoint') as mock_endpoint, \
-        patch('google.cloud.aiplatform.init') as mock_init:
+    with (
+        patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index,
+        patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint,
+        patch("google.cloud.aiplatform.init") as mock_init,
+    ):
        mock_index_instance = Mock()
        mock_endpoint_instance = Mock()
        yield {
-            'index': mock_index_instance,
-            'endpoint': mock_endpoint_instance,
-            'init': mock_init,
-            'index_class': mock_index,
-            'endpoint_class': mock_endpoint
+            "index": mock_index_instance,
+            "endpoint": mock_endpoint_instance,
+            "init": mock_init,
+            "index_class": mock_index,
+            "endpoint_class": mock_endpoint,
        }
@pytest.fixture
def config():
    return GoogleMatchingEngineConfig(
-        project_id='test-project',
-        project_number='123456789',
-        region='us-central1',
-        endpoint_id='test-endpoint',
-        index_id='test-index',
-        deployment_index_id='test-deployment',
-        collection_name='test-collection',
-        vector_search_api_endpoint='test.vertexai.goog'
+        project_id="test-project",
+        project_number="123456789",
+        region="us-central1",
+        endpoint_id="test-endpoint",
+        index_id="test-index",
+        deployment_index_id="test-deployment",
+        collection_name="test-collection",
+        vector_search_api_endpoint="test.vertexai.goog",
    )
@pytest.fixture
def vector_store(config, mock_vertex_ai):
-    mock_vertex_ai['index_class'].return_value = mock_vertex_ai['index']
-    mock_vertex_ai['endpoint_class'].return_value = mock_vertex_ai['endpoint']
+    mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"]
+    mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"]
    return GoogleMatchingEngine(**config.model_dump())
def test_initialization(vector_store, mock_vertex_ai, config):
    """Test proper initialization of GoogleMatchingEngine"""
-    mock_vertex_ai['init'].assert_called_once_with(
-        project=config.project_id,
-        location=config.region
-    )
+    mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region)
    expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}"
-    mock_vertex_ai['index_class'].assert_called_once_with(index_name=expected_index_path)
+    mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path)
def test_insert_vectors(vector_store, mock_vertex_ai):
    """Test inserting vectors with payloads"""
@@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai):
    vector_store.insert(vectors=vectors, payloads=payloads, ids=ids)
-    mock_vertex_ai['index'].upsert_datapoints.assert_called_once()
-    call_args = mock_vertex_ai['index'].upsert_datapoints.call_args[1]
-    assert len(call_args['datapoints']) == 1
-    datapoint_str = str(call_args['datapoints'][0])
+    mock_vertex_ai["index"].upsert_datapoints.assert_called_once()
+    call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1]
+    assert len(call_args["datapoints"]) == 1
+    datapoint_str = str(call_args["datapoints"][0])
    assert "test-id" in datapoint_str
    assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str
def test_search_vectors(vector_store, mock_vertex_ai):
    """Test searching vectors with filters"""
    vectors = [[0.1, 0.2, 0.3]]
@@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_restrict.allow_list = ["test_user"]
    mock_restrict.name = "user_id"
    mock_restrict.allow_tokens = ["test_user"]
    mock_datapoint.restricts = [mock_restrict]
    mock_neighbor = Mock()
@@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    mock_neighbor.datapoint = mock_datapoint
    mock_neighbor.restricts = [mock_restrict]
-    mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
+    mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]]
    results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)
-    mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
+    mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with(
        deployed_index_id=vector_store.deployment_index_id,
        queries=[vectors],
        num_neighbors=1,
        filter=[Namespace("user_id", ["test_user"], [])],
-        return_full_datapoint=True
+        return_full_datapoint=True,
    )
    assert len(results) == 1
@@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai):
    assert results[0].score == 0.1
    assert results[0].payload == {"user_id": "test_user"}
def test_delete(vector_store, mock_vertex_ai):
    """Test deleting vectors"""
    vector_id = "test-id"
    remove_mock = Mock()
-    with patch.object(GoogleMatchingEngine, 'delete', wraps=vector_store.delete) as delete_spy:
-        with patch.object(vector_store.index, 'remove_datapoints', remove_mock):
+    with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy:
+        with patch.object(vector_store.index, "remove_datapoints", remove_mock):
            vector_store.delete(ids=[vector_id])
            delete_spy.assert_called_once_with(ids=[vector_id])
            remove_mock.assert_called_once_with(datapoint_ids=[vector_id])
def test_error_handling(vector_store, mock_vertex_ai):
    """Test error handling during operations"""
-    mock_vertex_ai['index'].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
+    mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
    with pytest.raises(Exception) as exc_info:
-        vector_store.insert(
-            vectors=[[0.1, 0.2, 0.3]],
-            payloads=[{"name": "test"}],
-            ids=["test-id"]
-        )
+        vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"])
    assert isinstance(exc_info.value, exceptions.InvalidArgument)
    assert "Invalid request" in str(exc_info.value)


@@ -76,15 +76,15 @@
# self.client_mock.batch = MagicMock()
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
#     "results": [{"id": "id1"}, {"id": "id2"}]
# }
# vectors = [[0.1] * 1536, [0.2] * 1536]
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# def test_get(self):
@@ -108,7 +108,7 @@
# result = self.weaviate_db.get(vector_id=valid_uuid)
# assert result.id == valid_uuid
# expected_payload = mock_response.properties.copy()
# expected_payload["id"] = valid_uuid
@@ -131,10 +131,10 @@
#     "metadata": {"distance": 0.2}
# }
# ]
# mock_response = MagicMock()
# mock_response.objects = []
# for obj in mock_objects:
# mock_obj = MagicMock()
# mock_obj.uuid = obj["uuid"]
@@ -142,16 +142,16 @@
# mock_obj.metadata = MagicMock()
# mock_obj.metadata.distance = obj["metadata"]["distance"]
# mock_response.objects.append(mock_obj)
# mock_hybrid = MagicMock()
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
# mock_hybrid.return_value = mock_response
# vectors = [[0.1] * 1536]
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
# mock_hybrid.assert_called_once()
# self.assertEqual(len(results), 1)
# self.assertEqual(results[0].id, "id1")
# self.assertEqual(results[0].score, 0.8)
@@ -163,28 +163,28 @@
# def test_list(self):
# mock_objects = []
# mock_obj1 = MagicMock()
# mock_obj1.uuid = "id1"
# mock_obj1.properties = {"key1": "value1"}
# mock_objects.append(mock_obj1)
# mock_obj2 = MagicMock()
# mock_obj2.uuid = "id2"
# mock_obj2.properties = {"key2": "value2"}
# mock_objects.append(mock_obj2)
# mock_response = MagicMock()
# mock_response.objects = mock_objects
# mock_fetch = MagicMock()
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
# mock_fetch.return_value = mock_response
# results = self.weaviate_db.list(limit=10)
# mock_fetch.assert_called_once()
# # Verify results
# self.assertEqual(len(results), 1)
# self.assertEqual(len(results[0]), 2)