Formatting (#2750)

Authored by Dev Khant on 2025-05-22 01:17:29 +05:30; committed by GitHub.
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions
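
Most hunks below follow the same pattern: older single-quoted, hand-wrapped code is replaced by an autoformatted equivalent. As an illustration only (the diff does not name the formatter; a Black/Ruff-style tool is assumed), the shift looks roughly like this, using dictionary keys that appear in the evaluation scripts further down:

# Before (style of the removed lines): single quotes, hand-wrapped dict
#   result = {'bleu_score': scores['bleu1'],
#             'f1_score': metrics['f1']}
# After (style of the added lines): double quotes, trailing commas on multi-line literals
scores = {"bleu1": 0.42}  # dummy values, illustration only
metrics = {"f1": 0.37}
result = {
    "bleu_score": scores["bleu1"],
    "f1_score": metrics["f1"],
}
print(result)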


@@ -56,6 +56,7 @@ jobs:
run: |
make install_all
pip install -e ".[test]"
pip install pinecone pinecone-text
if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
- name: Run Formatting
run: |


@@ -13,7 +13,7 @@
"import anthropic\n",
"\n",
"# Set up environment variables\n",
"os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n",
"os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = \"your_anthropic_api_key\""
]
},
@@ -33,7 +33,7 @@
" \"model\": \"claude-3-5-sonnet-latest\",\n",
" \"temperature\": 0.1,\n",
" \"max_tokens\": 2000,\n",
" }\n",
" },\n",
" }\n",
" }\n",
" self.client = anthropic.Client(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n",
@@ -50,11 +50,7 @@
" - Keep track of open issues and follow-ups\n",
" \"\"\"\n",
"\n",
" def store_customer_interaction(self,\n",
" user_id: str,\n",
" message: str,\n",
" response: str,\n",
" metadata: Dict = None):\n",
" def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):\n",
" \"\"\"Store customer interaction in memory.\"\"\"\n",
" if metadata is None:\n",
" metadata = {}\n",
@@ -63,24 +59,17 @@
" metadata[\"timestamp\"] = datetime.now().isoformat()\n",
"\n",
" # Format conversation for storage\n",
" conversation = [\n",
" {\"role\": \"user\", \"content\": message},\n",
" {\"role\": \"assistant\", \"content\": response}\n",
" ]\n",
" conversation = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response}]\n",
"\n",
" # Store in Mem0\n",
" self.memory.add(\n",
" conversation,\n",
" user_id=user_id,\n",
" metadata=metadata\n",
" )\n",
" self.memory.add(conversation, user_id=user_id, metadata=metadata)\n",
"\n",
" def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:\n",
" \"\"\"Retrieve relevant past interactions.\"\"\"\n",
" return self.memory.search(\n",
" query=query,\n",
" user_id=user_id,\n",
" limit=5 # Adjust based on needs\n",
" limit=5, # Adjust based on needs\n",
" )\n",
"\n",
" def handle_customer_query(self, user_id: str, query: str) -> str:\n",
@@ -112,15 +101,12 @@
" model=\"claude-3-5-sonnet-latest\",\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" max_tokens=2000,\n",
" temperature=0.1\n",
" temperature=0.1,\n",
" )\n",
"\n",
" # Store interaction\n",
" self.store_customer_interaction(\n",
" user_id=user_id,\n",
" message=query,\n",
" response=response,\n",
" metadata={\"type\": \"support_query\"}\n",
" user_id=user_id, message=query, response=response, metadata={\"type\": \"support_query\"}\n",
" )\n",
"\n",
" return response.content[0].text"
@@ -203,12 +189,12 @@
" # Get user input\n",
" query = input()\n",
" print(\"Customer:\", query)\n",
" \n",
"\n",
" # Check if user wants to exit\n",
" if query.lower() == 'exit':\n",
" if query.lower() == \"exit\":\n",
" print(\"Thank you for using our support service. Goodbye!\")\n",
" break\n",
" \n",
"\n",
" # Handle the query and print the response\n",
" response = chatbot.handle_customer_query(user_id, query)\n",
" print(\"Support:\", response, \"\\n\\n\")"

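For reference, the store-and-retrieve pattern this notebook collapses onto single lines can be exercised on its own. A minimal sketch, assuming the open-source mem0 Memory client with its default OpenAI embedder; the user ID and messages are illustrative, not the notebook's:

# Minimal sketch of the notebook's add/search flow (illustrative values only).
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your_openai_api_key"  # needed for embedding model
memory = Memory()

conversation = [
    {"role": "user", "content": "My order hasn't arrived yet."},
    {"role": "assistant", "content": "Sorry about that, let me look into it."},
]
memory.add(conversation, user_id="customer_001", metadata={"type": "support_query"})

# Retrieve the most relevant past interactions for a new query.
relevant = memory.search(query="order status", user_id="customer_001", limit=5)
print(relevant)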

@@ -25,7 +25,8 @@
"source": [
"# Set up ENV Vars\n",
"import os\n",
"os.environ['OPENAI_API_KEY'] = \"sk-xxx\"\n"
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]
},
{
@@ -133,11 +134,9 @@
"assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n",
"\n",
"# LLM Configuration\n",
"CACHE_SEED = 42 # choose your poison\n",
"CACHE_SEED = 42 # choose your poison\n",
"llm_config = {\n",
" \"config_list\": [\n",
" {\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}\n",
" ],\n",
" \"config_list\": [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}],\n",
" \"cache_seed\": CACHE_SEED,\n",
" \"timeout\": 120,\n",
" \"temperature\": 0.0,\n",
@@ -348,7 +347,7 @@
"source": [
"# Retrieve the memory\n",
"relevant_memories = MEM0_MEMORY_CLIENT.search(user_query, user_id=USER_ID, limit=3)\n",
"relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n",
"relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n",
"print(\"Relevant memories:\")\n",
"print(relevant_memories_text)\n",
"\n",
@@ -389,8 +388,8 @@
"# - Enables more context-aware and personalized agent responses.\n",
"# - Bridges the gap between human input and AI processing in complex workflows.\n",
"\n",
"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
"\n",
"class Mem0ProxyCoderAgent(UserProxyAgent):\n",
" def __init__(self, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" self.memory = MEM0_MEMORY_CLIENT\n",
@@ -399,15 +398,14 @@
" def initiate_chat(self, assistant, message):\n",
" # Retrieve memory for the agent\n",
" agent_memories = self.memory.search(message, agent_id=self.agent_id, limit=3)\n",
" agent_memories_txt = '\\n'.join(mem['memory'] for mem in agent_memories)\n",
" agent_memories_txt = \"\\n\".join(mem[\"memory\"] for mem in agent_memories)\n",
" prompt = f\"{message}\\n Coding Preferences: \\n{str(agent_memories_txt)}\"\n",
" response = super().initiate_chat(assistant, message=prompt)\n",
" # Add new memory after processing the message\n",
" response_dist = response.__dict__ if not isinstance(response, dict) else response\n",
" MEMORY_DATA = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response_dist}]\n",
" self.memory.add(MEMORY_DATA, agent_id=self.agent_id)\n",
" return response\n",
" "
" return response"
]
},
{
@@ -560,12 +558,12 @@
"from cookbooks.helper.mem0_teachability import Mem0Teachability\n",
"\n",
"teachability = Mem0Teachability(\n",
" verbosity=2, # for visibility of what's happening\n",
" recall_threshold=0.5,\n",
" reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n",
" agent_id=AGENT_ID,\n",
" memory_client = MEM0_MEMORY_CLIENT,\n",
" )\n",
" verbosity=2, # for visibility of what's happening\n",
" recall_threshold=0.5,\n",
" reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n",
" agent_id=AGENT_ID,\n",
" memory_client=MEM0_MEMORY_CLIENT,\n",
")\n",
"teachability.add_to_agent(user_proxy)"
]
},


@@ -14,46 +14,47 @@ def process_item(item_data):
local_results = defaultdict(list)
for item in v:
gt_answer = str(item['answer'])
pred_answer = str(item['response'])
category = str(item['category'])
question = str(item['question'])
gt_answer = str(item["answer"])
pred_answer = str(item["response"])
category = str(item["category"])
question = str(item["question"])
# Skip category 5
if category == '5':
if category == "5":
continue
metrics = calculate_metrics(pred_answer, gt_answer)
bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)
local_results[k].append({
"question": question,
"answer": gt_answer,
"response": pred_answer,
"category": category,
"bleu_score": bleu_scores["bleu1"],
"f1_score": metrics["f1"],
"llm_score": llm_score
})
local_results[k].append(
{
"question": question,
"answer": gt_answer,
"response": pred_answer,
"category": category,
"bleu_score": bleu_scores["bleu1"],
"f1_score": metrics["f1"],
"llm_score": llm_score,
}
)
return local_results
def main():
parser = argparse.ArgumentParser(description='Evaluate RAG results')
parser.add_argument('--input_file', type=str,
default="results/rag_results_500_k1.json",
help='Path to the input dataset file')
parser.add_argument('--output_file', type=str,
default="evaluation_metrics.json",
help='Path to save the evaluation results')
parser.add_argument('--max_workers', type=int, default=10,
help='Maximum number of worker threads')
parser = argparse.ArgumentParser(description="Evaluate RAG results")
parser.add_argument(
"--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file"
)
parser.add_argument(
"--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results"
)
parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads")
args = parser.parse_args()
with open(args.input_file, 'r') as f:
with open(args.input_file, "r") as f:
data = json.load(f)
results = defaultdict(list)
@@ -61,18 +62,16 @@ def main():
# Use ThreadPoolExecutor with specified workers
with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
futures = [executor.submit(process_item, item_data)
for item_data in data.items()]
futures = [executor.submit(process_item, item_data) for item_data in data.items()]
for future in tqdm(concurrent.futures.as_completed(futures),
total=len(futures)):
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
local_results = future.result()
with results_lock:
for k, items in local_results.items():
results[k].extend(items)
# Save results to JSON file
with open(args.output_file, 'w') as f:
with open(args.output_file, "w") as f:
json.dump(results, f, indent=4)
print(f"Results saved to {args.output_file}")


@@ -3,7 +3,7 @@ import json
import pandas as pd
# Load the evaluation metrics data
with open('evaluation_metrics.json', 'r') as f:
with open("evaluation_metrics.json", "r") as f:
data = json.load(f)
# Flatten the data into a list of question items
@@ -15,28 +15,20 @@ for key in data:
df = pd.DataFrame(all_items)
# Convert category to numeric type
df['category'] = pd.to_numeric(df['category'])
df["category"] = pd.to_numeric(df["category"])
# Calculate mean scores by category
result = df.groupby('category').agg({
'bleu_score': 'mean',
'f1_score': 'mean',
'llm_score': 'mean'
}).round(4)
result = df.groupby("category").agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)
# Add count of questions per category
result['count'] = df.groupby('category').size()
result["count"] = df.groupby("category").size()
# Print the results
print("Mean Scores Per Category:")
print(result)
# Calculate overall means
overall_means = df.agg({
'bleu_score': 'mean',
'f1_score': 'mean',
'llm_score': 'mean'
}).round(4)
overall_means = df.agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)
print("\nOverall Mean Scores:")
print(overall_means)
print(overall_means)
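
A compact, runnable version of the aggregation above, with a couple of made-up rows standing in for evaluation_metrics.json:

# Illustrative only: per-category means as computed above, on dummy rows.
import pandas as pd

df = pd.DataFrame(
    [
        {"category": "1", "bleu_score": 0.2, "f1_score": 0.5, "llm_score": 1},
        {"category": "2", "bleu_score": 0.4, "f1_score": 0.3, "llm_score": 0},
    ]
)
df["category"] = pd.to_numeric(df["category"])
result = df.groupby("category").agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4)
result["count"] = df.groupby("category").size()
print(result)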


@@ -33,35 +33,34 @@ Do NOT include both CORRECT and WRONG in your response, or it will break the eva
Just return the label CORRECT or WRONG in a json format with the key as "label".
"""
def evaluate_llm_judge(question, gold_answer, generated_answer):
"""Evaluate the generated answer against the gold answer using an LLM judge."""
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[{
"role": "user",
"content": ACCURACY_PROMPT.format(
question=question,
gold_answer=gold_answer,
generated_answer=generated_answer
)
}],
messages=[
{
"role": "user",
"content": ACCURACY_PROMPT.format(
question=question, gold_answer=gold_answer, generated_answer=generated_answer
),
}
],
response_format={"type": "json_object"},
temperature=0.0
temperature=0.0,
)
label = json.loads(response.choices[0].message.content)['label']
label = json.loads(response.choices[0].message.content)["label"]
return 1 if label == "CORRECT" else 0
def main():
"""Main function to evaluate RAG results using LLM judge."""
parser = argparse.ArgumentParser(
description='Evaluate RAG results using LLM judge'
)
parser = argparse.ArgumentParser(description="Evaluate RAG results using LLM judge")
parser.add_argument(
'--input_file',
"--input_file",
type=str,
default="results/default_run_v4_k30_new_graph.json",
help='Path to the input dataset file'
help="Path to the input dataset file",
)
args = parser.parse_args()
@@ -78,10 +77,10 @@ def main():
index = 0
for k, v in data.items():
for x in v:
question = x['question']
gold_answer = x['answer']
generated_answer = x['response']
category = x['category']
question = x["question"]
gold_answer = x["answer"]
generated_answer = x["response"]
category = x["category"]
# Skip category 5
if int(category) == 5:
@@ -92,13 +91,15 @@ def main():
LLM_JUDGE[category].append(label)
# Store the results
RESULTS[index].append({
"question": question,
"gt_answer": gold_answer,
"response": generated_answer,
"category": category,
"llm_label": label
})
RESULTS[index].append(
{
"question": question,
"gt_answer": gold_answer,
"response": generated_answer,
"category": category,
"llm_label": label,
}
)
# Save intermediate results
with open(output_path, "w") as f:
@@ -108,8 +109,7 @@ def main():
print("All categories accuracy:")
for cat, results in LLM_JUDGE.items():
if results: # Only print if there are results for this category
print(f" Category {cat}: {np.mean(results):.4f} "
f"({sum(results)}/{len(results)})")
print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
print("------------------------------------------")
index += 1
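
The judge above boils down to a single JSON-mode chat completion; below is a stripped-down, self-contained version of that call (the prompt text is shortened for illustration, the real ACCURACY_PROMPT lives in this file):

# Illustrative reduction of evaluate_llm_judge: ask the model for {"label": "CORRECT" or "WRONG"}.
import json
from openai import OpenAI

client = OpenAI()  # requires OPENAI_API_KEY

prompt = (
    "Question: Where does the user live?\n"
    "Gold answer: Paris\n"
    "Generated answer: The user lives in Paris.\n"
    'Reply in JSON with a single key "label" set to CORRECT or WRONG.'
)
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": prompt}],
    response_format={"type": "json_object"},
    temperature=0.0,
)
label = json.loads(response.choices[0].message.content)["label"]
print(1 if label == "CORRECT" else 0)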


@@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py
@article{xu2025mem,
title={A-mem: Agentic memory for llm agents},
author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
and Zhang, Yongfeng},
journal={arXiv preprint arXiv:2502.12110},
year={2025}
@@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim
# Download required NLTK data
try:
nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download("punkt", quiet=True)
nltk.download("wordnet", quiet=True)
except Exception as e:
print(f"Error downloading NLTK data: {e}")
# Initialize SentenceTransformer model (this will be reused)
try:
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as e:
print(f"Warning: Could not load SentenceTransformer model: {e}")
sentence_model = None
def simple_tokenize(text):
"""Simple tokenization function."""
# Convert to string if not already
text = str(text)
return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split()
def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
"""Calculate ROUGE scores for prediction against reference."""
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
scores = scorer.score(reference, prediction)
return {
'rouge1_f': scores['rouge1'].fmeasure,
'rouge2_f': scores['rouge2'].fmeasure,
'rougeL_f': scores['rougeL'].fmeasure
"rouge1_f": scores["rouge1"].fmeasure,
"rouge2_f": scores["rouge2"].fmeasure,
"rougeL_f": scores["rougeL"].fmeasure,
}
def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
"""Calculate BLEU scores with different n-gram settings."""
pred_tokens = nltk.word_tokenize(prediction.lower())
ref_tokens = [nltk.word_tokenize(reference.lower())]
weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
smooth = SmoothingFunction().method1
scores = {}
for n, weights in enumerate(weights_list, start=1):
try:
@@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
except Exception as e:
print(f"Error calculating BLEU score: {e}")
score = 0.0
scores[f'bleu{n}'] = score
scores[f"bleu{n}"] = score
return scores
def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
"""Calculate BERTScore for semantic similarity."""
try:
P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
return {
'bert_precision': P.item(),
'bert_recall': R.item(),
'bert_f1': F1.item()
}
P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
except Exception as e:
print(f"Error calculating BERTScore: {e}")
return {
'bert_precision': 0.0,
'bert_recall': 0.0,
'bert_f1': 0.0
}
return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}
def calculate_meteor_score(prediction: str, reference: str) -> float:
"""Calculate METEOR score for the prediction."""
@@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float:
print(f"Error calculating METEOR score: {e}")
return 0.0
def calculate_sentence_similarity(prediction: str, reference: str) -> float:
"""Calculate sentence embedding similarity using SentenceBERT."""
if sentence_model is None:
@@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
# Encode sentences
embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
embedding2 = sentence_model.encode([reference], convert_to_tensor=True)
# Calculate cosine similarity
similarity = pytorch_cos_sim(embedding1, embedding2).item()
return float(similarity)
@@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
print(f"Error calculating sentence similarity: {e}")
return 0.0
def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
"""Calculate comprehensive evaluation metrics for a prediction."""
# Handle empty or None values
@@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
"bleu4": 0.0,
"bert_f1": 0.0,
"meteor": 0.0,
"sbert_similarity": 0.0
"sbert_similarity": 0.0,
}
# Convert to strings if they're not already
prediction = str(prediction).strip()
reference = str(reference).strip()
# Calculate exact match
exact_match = int(prediction.lower() == reference.lower())
# Calculate token-based F1 score
pred_tokens = set(simple_tokenize(prediction))
ref_tokens = set(simple_tokenize(reference))
common_tokens = pred_tokens & ref_tokens
if not pred_tokens or not ref_tokens:
f1 = 0.0
else:
precision = len(common_tokens) / len(pred_tokens)
recall = len(common_tokens) / len(ref_tokens)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
# Calculate all scores
bleu_scores = calculate_bleu_scores(prediction, reference)
# Combine all metrics
metrics = {
"exact_match": exact_match,
@@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
return metrics
def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
def aggregate_metrics(
all_metrics: List[Dict[str, float]], all_categories: List[int]
) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
"""Calculate aggregate statistics for all metrics, split by category."""
if not all_metrics:
return {}
# Initialize aggregates for overall and per-category metrics
aggregates = defaultdict(list)
category_aggregates = defaultdict(lambda: defaultdict(list))
# Collect all values for each metric, both overall and per category
for metrics, category in zip(all_metrics, all_categories):
for metric_name, value in metrics.items():
aggregates[metric_name].append(value)
category_aggregates[category][metric_name].append(value)
# Calculate statistics for overall metrics
results = {
"overall": {}
}
results = {"overall": {}}
for metric_name, values in aggregates.items():
results["overall"][metric_name] = {
'mean': statistics.mean(values),
'std': statistics.stdev(values) if len(values) > 1 else 0.0,
'median': statistics.median(values),
'min': min(values),
'max': max(values),
'count': len(values)
"mean": statistics.mean(values),
"std": statistics.stdev(values) if len(values) > 1 else 0.0,
"median": statistics.median(values),
"min": min(values),
"max": max(values),
"count": len(values),
}
# Calculate statistics for each category
for category in sorted(category_aggregates.keys()):
results[f"category_{category}"] = {}
for metric_name, values in category_aggregates[category].items():
if values: # Only calculate if we have values for this category
results[f"category_{category}"][metric_name] = {
'mean': statistics.mean(values),
'std': statistics.stdev(values) if len(values) > 1 else 0.0,
'median': statistics.median(values),
'min': min(values),
'max': max(values),
'count': len(values)
"mean": statistics.mean(values),
"std": statistics.stdev(values) if len(values) > 1 else 0.0,
"median": statistics.median(values),
"min": min(values),
"max": max(values),
"count": len(values),
}
return results


@@ -144,4 +144,4 @@ ANSWER_PROMPT_ZEP = """
Question: {{question}}
Answer:
"""
"""


@@ -21,23 +21,15 @@ class Experiment:
def main():
parser = argparse.ArgumentParser(description='Run memory experiments')
parser.add_argument('--technique_type', choices=TECHNIQUES, default='mem0',
help='Memory technique to use')
parser.add_argument('--method', choices=METHODS, default='add',
help='Method to use')
parser.add_argument('--chunk_size', type=int, default=1000,
help='Chunk size for processing')
parser.add_argument('--output_folder', type=str, default='results/',
help='Output path for results')
parser.add_argument('--top_k', type=int, default=30,
help='Number of top memories to retrieve')
parser.add_argument('--filter_memories', action='store_true', default=False,
help='Whether to filter memories')
parser.add_argument('--is_graph', action='store_true', default=False,
help='Whether to use graph-based search')
parser.add_argument('--num_chunks', type=int, default=1,
help='Number of chunks to process')
parser = argparse.ArgumentParser(description="Run memory experiments")
parser.add_argument("--technique_type", choices=TECHNIQUES, default="mem0", help="Memory technique to use")
parser.add_argument("--method", choices=METHODS, default="add", help="Method to use")
parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for processing")
parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results")
parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
parser.add_argument("--filter_memories", action="store_true", default=False, help="Whether to filter memories")
parser.add_argument("--is_graph", action="store_true", default=False, help="Whether to use graph-based search")
parser.add_argument("--num_chunks", type=int, default=1, help="Number of chunks to process")
args = parser.parse_args()
@@ -46,33 +38,18 @@ def main():
if args.technique_type == "mem0":
if args.method == "add":
memory_manager = MemoryADD(
data_path='dataset/locomo10.json',
is_graph=args.is_graph
)
memory_manager = MemoryADD(data_path="dataset/locomo10.json", is_graph=args.is_graph)
memory_manager.process_all_conversations()
elif args.method == "search":
output_file_path = os.path.join(
args.output_folder,
f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json"
f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json",
)
memory_searcher = MemorySearch(
output_file_path,
args.top_k,
args.filter_memories,
args.is_graph
)
memory_searcher.process_data_file('dataset/locomo10.json')
memory_searcher = MemorySearch(output_file_path, args.top_k, args.filter_memories, args.is_graph)
memory_searcher.process_data_file("dataset/locomo10.json")
elif args.technique_type == "rag":
output_file_path = os.path.join(
args.output_folder,
f"rag_results_{args.chunk_size}_k{args.num_chunks}.json"
)
rag_manager = RAGManager(
data_path="dataset/locomo10_rag.json",
chunk_size=args.chunk_size,
k=args.num_chunks
)
output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json")
rag_manager = RAGManager(data_path="dataset/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks)
rag_manager.process_all_conversations(output_file_path)
elif args.technique_type == "langmem":
output_file_path = os.path.join(args.output_folder, "langmem_results.json")
@@ -85,11 +62,7 @@ def main():
elif args.method == "search":
output_file_path = os.path.join(args.output_folder, "zep_search_results.json")
zep_manager = ZepSearch()
zep_manager.process_data_file(
"dataset/locomo10.json",
"1",
output_file_path
)
zep_manager.process_data_file("dataset/locomo10.json", "1", output_file_path)
elif args.technique_type == "openai":
output_file_path = os.path.join(args.output_folder, "openai_results.json")
openai_manager = OpenAIPredict()


@@ -28,14 +28,12 @@ def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_i
speaker_1_user_id=speaker_1_user_id,
speaker_1_memories=speaker_1_memories,
speaker_2_user_id=speaker_2_user_id,
speaker_2_memories=speaker_2_memories
speaker_2_memories=speaker_2_memories,
)
t1 = time.time()
response = client.chat.completions.create(
model=os.getenv("MODEL"),
messages=[{"role": "system", "content": prompt}],
temperature=0.0
model=os.getenv("MODEL"), messages=[{"role": "system", "content": prompt}], temperature=0.0
)
t2 = time.time()
return response.choices[0].message.content, t2 - t1
@@ -59,7 +57,9 @@ def prompt(state):
class LangMem:
def __init__(self,):
def __init__(
self,
):
self.store = InMemoryStore(
index={
"dims": 1536,
@@ -80,18 +80,12 @@ class LangMem:
)
def add_memory(self, message, config):
return self.agent.invoke(
{"messages": [{"role": "user", "content": message}]},
config=config
)
return self.agent.invoke({"messages": [{"role": "user", "content": message}]}, config=config)
def search_memory(self, query, config):
try:
t1 = time.time()
response = self.agent.invoke(
{"messages": [{"role": "user", "content": query}]},
config=config
)
response = self.agent.invoke({"messages": [{"role": "user", "content": query}]}, config=config)
t2 = time.time()
return response["messages"][-1].content, t2 - t1
except Exception as e:
@@ -102,7 +96,7 @@ class LangMem:
class LangMemManager:
def __init__(self, dataset_path):
self.dataset_path = dataset_path
with open(self.dataset_path, 'r') as f:
with open(self.dataset_path, "r") as f:
self.data = json.load(f)
def process_all_conversations(self, output_file_path):
@@ -123,7 +117,7 @@ class LangMemManager:
# Identify speakers
for conv in chat_history:
speakers.add(conv['speaker'])
speakers.add(conv["speaker"])
if len(speakers) != 2:
raise ValueError(f"Expected 2 speakers, got {len(speakers)}")
@@ -134,50 +128,52 @@ class LangMemManager:
# Add memories for each message
for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
if conv['speaker'] == speaker1:
if conv["speaker"] == speaker1:
agent1.add_memory(message, config)
elif conv['speaker'] == speaker2:
elif conv["speaker"] == speaker2:
agent2.add_memory(message, config)
else:
raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}")
# Process questions
for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
category = q['category']
category = q["category"]
if int(category) == 5:
continue
answer = q['answer']
question = q['question']
answer = q["answer"]
question = q["question"]
response1, speaker1_memory_time = agent1.search_memory(question, config)
response2, speaker2_memory_time = agent2.search_memory(question, config)
generated_answer, response_time = get_answer(
question, speaker1, response1, speaker2, response2
)
generated_answer, response_time = get_answer(question, speaker1, response1, speaker2, response2)
result[key].append({
"question": question,
"answer": answer,
"response1": response1,
"response2": response2,
"category": category,
"speaker1_memory_time": speaker1_memory_time,
"speaker2_memory_time": speaker2_memory_time,
"response_time": response_time,
'response': generated_answer
})
result[key].append(
{
"question": question,
"answer": answer,
"response1": response1,
"response2": response2,
"category": category,
"speaker1_memory_time": speaker1_memory_time,
"speaker2_memory_time": speaker2_memory_time,
"response_time": response_time,
"response": generated_answer,
}
)
return result
# Use multiprocessing to process conversations in parallel
with mp.Pool(processes=10) as pool:
results = list(tqdm(
pool.imap(process_conversation, list(self.data.items())),
total=len(self.data),
desc="Processing conversations"
))
results = list(
tqdm(
pool.imap(process_conversation, list(self.data.items())),
total=len(self.data),
desc="Processing conversations",
)
)
# Combine results from all workers
for result in results:
@@ -185,5 +181,5 @@ class LangMemManager:
OUTPUT[key].extend(items)
# Save final results
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(OUTPUT, f, indent=4)

View File

@@ -13,7 +13,7 @@ load_dotenv()
# Update custom instructions
custom_instructions ="""
custom_instructions = """
Generate personal memories that follow these guidelines:
1. Each memory should be self-contained with complete context, including:
@@ -47,7 +47,7 @@ class MemoryADD:
self.mem0_client = MemoryClient(
api_key=os.getenv("MEM0_API_KEY"),
org_id=os.getenv("MEM0_ORGANIZATION_ID"),
project_id=os.getenv("MEM0_PROJECT_ID")
project_id=os.getenv("MEM0_PROJECT_ID"),
)
self.mem0_client.update_project(custom_instructions=custom_instructions)
@@ -59,15 +59,16 @@ class MemoryADD:
self.load_data()
def load_data(self):
with open(self.data_path, 'r') as f:
with open(self.data_path, "r") as f:
self.data = json.load(f)
return self.data
def add_memory(self, user_id, message, metadata, retries=3):
for attempt in range(retries):
try:
_ = self.mem0_client.add(message, user_id=user_id, version="v2",
metadata=metadata, enable_graph=self.is_graph)
_ = self.mem0_client.add(
message, user_id=user_id, version="v2", metadata=metadata, enable_graph=self.is_graph
)
return
except Exception as e:
if attempt < retries - 1:
@@ -78,13 +79,13 @@ class MemoryADD:
def add_memories_for_speaker(self, speaker, messages, timestamp, desc):
for i in tqdm(range(0, len(messages), self.batch_size), desc=desc):
batch_messages = messages[i:i+self.batch_size]
batch_messages = messages[i : i + self.batch_size]
self.add_memory(speaker, batch_messages, metadata={"timestamp": timestamp})
def process_conversation(self, item, idx):
conversation = item['conversation']
speaker_a = conversation['speaker_a']
speaker_b = conversation['speaker_b']
conversation = item["conversation"]
speaker_a = conversation["speaker_a"]
speaker_b = conversation["speaker_b"]
speaker_a_user_id = f"{speaker_a}_{idx}"
speaker_b_user_id = f"{speaker_b}_{idx}"
@@ -94,7 +95,7 @@ class MemoryADD:
self.mem0_client.delete_all(user_id=speaker_b_user_id)
for key in conversation.keys():
if key in ['speaker_a', 'speaker_b'] or "date" in key or "timestamp" in key:
if key in ["speaker_a", "speaker_b"] or "date" in key or "timestamp" in key:
continue
date_time_key = key + "_date_time"
@@ -104,10 +105,10 @@ class MemoryADD:
messages = []
messages_reverse = []
for chat in chats:
if chat['speaker'] == speaker_a:
if chat["speaker"] == speaker_a:
messages.append({"role": "user", "content": f"{speaker_a}: {chat['text']}"})
messages_reverse.append({"role": "assistant", "content": f"{speaker_a}: {chat['text']}"})
elif chat['speaker'] == speaker_b:
elif chat["speaker"] == speaker_b:
messages.append({"role": "assistant", "content": f"{speaker_b}: {chat['text']}"})
messages_reverse.append({"role": "user", "content": f"{speaker_b}: {chat['text']}"})
else:
@@ -116,11 +117,11 @@ class MemoryADD:
# add memories for the two users on different threads
thread_a = threading.Thread(
target=self.add_memories_for_speaker,
args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A")
args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A"),
)
thread_b = threading.Thread(
target=self.add_memories_for_speaker,
args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B")
args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B"),
)
thread_a.start()
@@ -134,10 +135,7 @@ class MemoryADD:
if not self.data:
raise ValueError("No data loaded. Please set data_path and call load_data() first.")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(self.process_conversation, item, idx)
for idx, item in enumerate(self.data)
]
futures = [executor.submit(self.process_conversation, item, idx) for idx, item in enumerate(self.data)]
for future in futures:
future.result()
future.result()


@@ -16,12 +16,11 @@ load_dotenv()
class MemorySearch:
def __init__(self, output_path='results.json', top_k=10, filter_memories=False, is_graph=False):
def __init__(self, output_path="results.json", top_k=10, filter_memories=False, is_graph=False):
self.mem0_client = MemoryClient(
api_key=os.getenv("MEM0_API_KEY"),
org_id=os.getenv("MEM0_ORGANIZATION_ID"),
project_id=os.getenv("MEM0_PROJECT_ID")
project_id=os.getenv("MEM0_PROJECT_ID"),
)
self.top_k = top_k
self.openai_client = OpenAI()
@@ -42,11 +41,18 @@ class MemorySearch:
try:
if self.is_graph:
print("Searching with graph")
memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
filter_memories=self.filter_memories, enable_graph=True, output_format='v1.1')
memories = self.mem0_client.search(
query,
user_id=user_id,
top_k=self.top_k,
filter_memories=self.filter_memories,
enable_graph=True,
output_format="v1.1",
)
else:
memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k,
filter_memories=self.filter_memories)
memories = self.mem0_client.search(
query, user_id=user_id, top_k=self.top_k, filter_memories=self.filter_memories
)
break
except Exception as e:
print("Retrying...")
@@ -57,64 +63,86 @@ class MemorySearch:
end_time = time.time()
if not self.is_graph:
semantic_memories = [{'memory': memory['memory'],
'timestamp': memory['metadata']['timestamp'],
'score': round(memory['score'], 2)}
for memory in memories]
semantic_memories = [
{
"memory": memory["memory"],
"timestamp": memory["metadata"]["timestamp"],
"score": round(memory["score"], 2),
}
for memory in memories
]
graph_memories = None
else:
semantic_memories = [{'memory': memory['memory'],
'timestamp': memory['metadata']['timestamp'],
'score': round(memory['score'], 2)} for memory in memories['results']]
graph_memories = [{"source": relation['source'], "relationship": relation['relationship'], "target": relation['target']} for relation in memories['relations']]
semantic_memories = [
{
"memory": memory["memory"],
"timestamp": memory["metadata"]["timestamp"],
"score": round(memory["score"], 2),
}
for memory in memories["results"]
]
graph_memories = [
{"source": relation["source"], "relationship": relation["relationship"], "target": relation["target"]}
for relation in memories["relations"]
]
return semantic_memories, graph_memories, end_time - start_time
def answer_question(self, speaker_1_user_id, speaker_2_user_id, question, answer, category):
speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(speaker_1_user_id, question)
speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(speaker_2_user_id, question)
speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(
speaker_1_user_id, question
)
speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(
speaker_2_user_id, question
)
search_1_memory = [f"{item['timestamp']}: {item['memory']}"
for item in speaker_1_memories]
search_2_memory = [f"{item['timestamp']}: {item['memory']}"
for item in speaker_2_memories]
search_1_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_1_memories]
search_2_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_2_memories]
template = Template(self.ANSWER_PROMPT)
answer_prompt = template.render(
speaker_1_user_id=speaker_1_user_id.split('_')[0],
speaker_2_user_id=speaker_2_user_id.split('_')[0],
speaker_1_user_id=speaker_1_user_id.split("_")[0],
speaker_2_user_id=speaker_2_user_id.split("_")[0],
speaker_1_memories=json.dumps(search_1_memory, indent=4),
speaker_2_memories=json.dumps(search_2_memory, indent=4),
speaker_1_graph_memories=json.dumps(speaker_1_graph_memories, indent=4),
speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4),
question=question
question=question,
)
t1 = time.time()
response = self.openai_client.chat.completions.create(
model=os.getenv("MODEL"),
messages=[
{"role": "system", "content": answer_prompt}
],
temperature=0.0
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
)
t2 = time.time()
response_time = t2 - t1
return response.choices[0].message.content, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time
return (
response.choices[0].message.content,
speaker_1_memories,
speaker_2_memories,
speaker_1_memory_time,
speaker_2_memory_time,
speaker_1_graph_memories,
speaker_2_graph_memories,
response_time,
)
def process_question(self, val, speaker_a_user_id, speaker_b_user_id):
question = val.get('question', '')
answer = val.get('answer', '')
category = val.get('category', -1)
evidence = val.get('evidence', [])
adversarial_answer = val.get('adversarial_answer', '')
question = val.get("question", "")
answer = val.get("answer", "")
category = val.get("category", -1)
evidence = val.get("evidence", [])
adversarial_answer = val.get("adversarial_answer", "")
response, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time = self.answer_question(
speaker_a_user_id,
speaker_b_user_id,
question,
answer,
category
)
(
response,
speaker_1_memories,
speaker_2_memories,
speaker_1_memory_time,
speaker_2_memory_time,
speaker_1_graph_memories,
speaker_2_graph_memories,
response_time,
) = self.answer_question(speaker_a_user_id, speaker_b_user_id, question, answer, category)
result = {
"question": question,
@@ -125,67 +153,63 @@ class MemorySearch:
"adversarial_answer": adversarial_answer,
"speaker_1_memories": speaker_1_memories,
"speaker_2_memories": speaker_2_memories,
'num_speaker_1_memories': len(speaker_1_memories),
'num_speaker_2_memories': len(speaker_2_memories),
'speaker_1_memory_time': speaker_1_memory_time,
'speaker_2_memory_time': speaker_2_memory_time,
"num_speaker_1_memories": len(speaker_1_memories),
"num_speaker_2_memories": len(speaker_2_memories),
"speaker_1_memory_time": speaker_1_memory_time,
"speaker_2_memory_time": speaker_2_memory_time,
"speaker_1_graph_memories": speaker_1_graph_memories,
"speaker_2_graph_memories": speaker_2_graph_memories,
"response_time": response_time
"response_time": response_time,
}
# Save results after each question is processed
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4)
return result
def process_data_file(self, file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa']
conversation = item['conversation']
speaker_a = conversation['speaker_a']
speaker_b = conversation['speaker_b']
qa = item["qa"]
conversation = item["conversation"]
speaker_a = conversation["speaker_a"]
speaker_b = conversation["speaker_b"]
speaker_a_user_id = f"{speaker_a}_{idx}"
speaker_b_user_id = f"{speaker_b}_{idx}"
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
result = self.process_question(
question_item,
speaker_a_user_id,
speaker_b_user_id
)
for question_item in tqdm(
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
):
result = self.process_question(question_item, speaker_a_user_id, speaker_b_user_id)
self.results[idx].append(result)
# Save results after each question is processed
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4)
# Final save at the end
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4)
def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1):
def process_single_question(val):
result = self.process_question(val, speaker_a_user_id, speaker_b_user_id)
# Save results after each question is processed
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4)
return result
with ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(tqdm(
executor.map(process_single_question, qa_list),
total=len(qa_list),
desc="Answering Questions"
))
results = list(
tqdm(executor.map(process_single_question, qa_list), total=len(qa_list), desc="Answering Questions")
)
# Final save at the end
with open(self.output_path, 'w') as f:
with open(self.output_path, "w") as f:
json.dump(self.results, f, indent=4)
return results


@@ -59,23 +59,19 @@ class OpenAIPredict:
self.results = defaultdict(list)
def search_memory(self, idx):
with open(f'memories/{idx}.txt', 'r') as file:
with open(f"memories/{idx}.txt", "r") as file:
memories = file.read()
return memories, 0
def process_question(self, val, idx):
question = val.get('question', '')
answer = val.get('answer', '')
category = val.get('category', -1)
evidence = val.get('evidence', [])
adversarial_answer = val.get('adversarial_answer', '')
question = val.get("question", "")
answer = val.get("answer", "")
category = val.get("category", -1)
evidence = val.get("evidence", [])
adversarial_answer = val.get("adversarial_answer", "")
response, search_memory_time, response_time, context = self.answer_question(
idx,
question
)
response, search_memory_time, response_time, context = self.answer_question(idx, question)
result = {
"question": question,
@@ -86,7 +82,7 @@ class OpenAIPredict:
"adversarial_answer": adversarial_answer,
"search_memory_time": search_memory_time,
"response_time": response_time,
"context": context
"context": context,
}
return result
@@ -95,43 +91,35 @@ class OpenAIPredict:
memories, search_memory_time = self.search_memory(idx)
template = Template(ANSWER_PROMPT)
answer_prompt = template.render(
memories=memories,
question=question
)
answer_prompt = template.render(memories=memories, question=question)
t1 = time.time()
response = self.openai_client.chat.completions.create(
model=os.getenv("MODEL"),
messages=[
{"role": "system", "content": answer_prompt}
],
temperature=0.0
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
)
t2 = time.time()
response_time = t2 - t1
return response.choices[0].message.content, search_memory_time, response_time, memories
def process_data_file(self, file_path, output_file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa']
qa = item["qa"]
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
result = self.process_question(
question_item,
idx
)
for question_item in tqdm(
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
):
result = self.process_question(question_item, idx)
self.results[idx].append(result)
# Save results after each question is processed
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)
# Final save at the end
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)
@@ -141,4 +129,3 @@ if __name__ == "__main__":
args = parser.parse_args()
openai_predict = OpenAIPredict()
openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path)


@@ -33,10 +33,7 @@ class RAGManager:
def generate_response(self, question, context):
template = Template(PROMPT)
prompt = template.render(
CONTEXT=context,
QUESTION=question
)
prompt = template.render(CONTEXT=context, QUESTION=question)
max_retries = 3
retries = 0
@@ -47,19 +44,21 @@ class RAGManager:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system",
"content": "You are a helpful assistant that can answer "
"questions based on the provided context."
"If the question involves timing, use the conversation date for reference."
"Provide the shortest possible answer."
"Use words directly from the conversation when possible."
"Avoid using subjects in your answer."},
{"role": "user", "content": prompt}
{
"role": "system",
"content": "You are a helpful assistant that can answer "
"questions based on the provided context."
"If the question involves timing, use the conversation date for reference."
"Provide the shortest possible answer."
"Use words directly from the conversation when possible."
"Avoid using subjects in your answer.",
},
{"role": "user", "content": prompt},
],
temperature=0
temperature=0,
)
t2 = time.time()
return response.choices[0].message.content.strip(), t2-t1
return response.choices[0].message.content.strip(), t2 - t1
except Exception as e:
retries += 1
if retries > max_retries:
@@ -69,21 +68,16 @@ class RAGManager:
def clean_chat_history(self, chat_history):
cleaned_chat_history = ""
for c in chat_history:
cleaned_chat_history += (f"{c['timestamp']} | {c['speaker']}: "
f"{c['text']}\n")
cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n"
return cleaned_chat_history
def calculate_embedding(self, document):
response = self.client.embeddings.create(
model=os.getenv("EMBEDDING_MODEL"),
input=document
)
response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
return response.data[0].embedding
def calculate_similarity(self, embedding1, embedding2):
return np.dot(embedding1, embedding2) / (
np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
def search(self, query, chunks, embeddings, k=1):
"""
@@ -101,10 +95,7 @@ class RAGManager:
"""
t1 = time.time()
query_embedding = self.calculate_embedding(query)
similarities = [
self.calculate_similarity(query_embedding, embedding)
for embedding in embeddings
]
similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings]
# Get indices of top-k most similar chunks
if k == 1:
@@ -118,7 +109,7 @@ class RAGManager:
combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])
t2 = time.time()
return combined_chunks, t2-t1
return combined_chunks, t2 - t1
def create_chunks(self, chat_history, chunk_size=500):
"""
@@ -139,7 +130,7 @@ class RAGManager:
# Split into chunks based on token count
for i in range(0, len(tokens), chunk_size):
chunk_tokens = tokens[i:i+chunk_size]
chunk_tokens = tokens[i : i + chunk_size]
chunk = encoding.decode(chunk_tokens)
chunks.append(chunk)
@@ -159,13 +150,9 @@ class RAGManager:
chat_history = value["conversation"]
questions = value["question"]
chunks, embeddings = self.create_chunks(
chat_history, self.chunk_size
)
chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)
for item in tqdm(
questions, desc="Answering questions", leave=False
):
for item in tqdm(questions, desc="Answering questions", leave=False):
question = item["question"]
answer = item.get("answer", "")
category = item["category"]
@@ -174,22 +161,20 @@ class RAGManager:
context = chunks[0]
search_time = 0
else:
context, search_time = self.search(
question, chunks, embeddings, k=self.k
)
response, response_time = self.generate_response(
question, context
)
context, search_time = self.search(question, chunks, embeddings, k=self.k)
response, response_time = self.generate_response(question, context)
FINAL_RESULTS[key].append({
"question": question,
"answer": answer,
"category": category,
"context": context,
"response": response,
"search_time": search_time,
"response_time": response_time,
})
FINAL_RESULTS[key].append(
{
"question": question,
"answer": answer,
"category": category,
"context": context,
"response": response,
"search_time": search_time,
"response_time": response_time,
}
)
with open(output_file_path, "w+") as f:
json.dump(FINAL_RESULTS, f, indent=4)
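
The retrieval in this file is a plain embedding-similarity search; here is a small self-contained sketch of the same math, with hard-coded vectors standing in for the OpenAI embeddings the script requests:

# Illustrative only: cosine similarity and top-1 selection as in RAGManager.search (k=1 case).
import numpy as np

def cosine_similarity(embedding1, embedding2):
    return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))

query_embedding = np.array([0.1, 0.9, 0.2])  # stand-in for calculate_embedding(query)
chunk_embeddings = [np.array([0.1, 0.8, 0.3]), np.array([0.9, 0.1, 0.0])]  # stand-ins for chunk embeddings

similarities = [cosine_similarity(query_embedding, e) for e in chunk_embeddings]
top_index = int(np.argmax(similarities))
print(top_index, similarities[top_index])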


@@ -1,12 +1,3 @@
TECHNIQUES = [
"mem0",
"rag",
"langmem",
"zep",
"openai"
]
TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"]
METHODS = [
"add",
"search"
]
METHODS = ["add", "search"]


@@ -19,12 +19,12 @@ class ZepAdd:
self.load_data()
def load_data(self):
with open(self.data_path, 'r') as f:
with open(self.data_path, "r") as f:
self.data = json.load(f)
return self.data
def process_conversation(self, run_id, item, idx):
conversation = item['conversation']
conversation = item["conversation"]
user_id = f"run_id_{run_id}_experiment_user_{idx}"
session_id = f"run_id_{run_id}_experiment_session_{idx}"
@@ -41,7 +41,7 @@ class ZepAdd:
print("Starting to add memories... for user", user_id)
for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
if key in ['speaker_a', 'speaker_b'] or "date" in key:
if key in ["speaker_a", "speaker_b"] or "date" in key:
continue
date_time_key = key + "_date_time"
@@ -51,11 +51,13 @@ class ZepAdd:
for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
self.zep_client.memory.add(
session_id=session_id,
messages=[Message(
role=chat['speaker'],
role_type="user",
content=f"{timestamp}: {chat['text']}",
)]
messages=[
Message(
role=chat["speaker"],
role_type="user",
content=f"{timestamp}: {chat['text']}",
)
],
)
def process_all_conversations(self, run_id):
@@ -71,4 +73,4 @@ if __name__ == "__main__":
parser.add_argument("--run_id", type=str, required=True)
args = parser.parse_args()
zep_add = ZepAdd(data_path="../../dataset/locomo10.json")
zep_add.process_all_conversations(args.run_id)
zep_add.process_all_conversations(args.run_id)


@@ -42,9 +42,9 @@ class ZepSearch:
return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"
def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str:
facts = [f' - {edge.fact} ({self.format_edge_date_range(edge)})' for edge in edges]
entities = [f' - {node.name}: {node.summary}' for node in nodes]
return TEMPLATE.format(facts='\n'.join(facts), entities='\n'.join(entities))
facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges]
entities = [f" - {node.name}: {node.summary}" for node in nodes]
return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))
def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
start_time = time.time()
@@ -52,8 +52,14 @@ class ZepSearch:
while retries < max_retries:
try:
user_id = f"run_id_{run_id}_experiment_user_{idx}"
edges_results = (self.zep_client.graph.search(user_id=user_id, reranker='cross_encoder', query=query, scope='edges', limit=20)).edges
node_results = (self.zep_client.graph.search(user_id=user_id, reranker='rrf', query=query, scope='nodes', limit=20)).nodes
edges_results = (
self.zep_client.graph.search(
user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
)
).edges
node_results = (
self.zep_client.graph.search(user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20)
).nodes
context = self.compose_search_context(edges_results, node_results)
break
except Exception as e:
@@ -68,17 +74,13 @@ class ZepSearch:
return context, end_time - start_time
def process_question(self, run_id, val, idx):
question = val.get('question', '')
answer = val.get('answer', '')
category = val.get('category', -1)
evidence = val.get('evidence', [])
adversarial_answer = val.get('adversarial_answer', '')
question = val.get("question", "")
answer = val.get("answer", "")
category = val.get("category", -1)
evidence = val.get("evidence", [])
adversarial_answer = val.get("adversarial_answer", "")
response, search_memory_time, response_time, context = self.answer_question(
run_id,
idx,
question
)
response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)
result = {
"question": question,
@@ -89,7 +91,7 @@ class ZepSearch:
"adversarial_answer": adversarial_answer,
"search_memory_time": search_memory_time,
"response_time": response_time,
"context": context
"context": context,
}
return result
@@ -98,44 +100,35 @@ class ZepSearch:
context, search_memory_time = self.search_memory(run_id, idx, question)
template = Template(ANSWER_PROMPT_ZEP)
answer_prompt = template.render(
memories=context,
question=question
)
answer_prompt = template.render(memories=context, question=question)
t1 = time.time()
response = self.openai_client.chat.completions.create(
model=os.getenv("MODEL"),
messages=[
{"role": "system", "content": answer_prompt}
],
temperature=0.0
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
)
t2 = time.time()
response_time = t2 - t1
return response.choices[0].message.content, search_memory_time, response_time, context
def process_data_file(self, file_path, run_id, output_file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa']
qa = item["qa"]
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
result = self.process_question(
run_id,
question_item,
idx
)
for question_item in tqdm(
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
):
result = self.process_question(run_id, question_item, idx)
self.results[idx].append(result)
# Save results after each question is processed
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)
# Final save at the end
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)


@@ -56,9 +56,7 @@
"\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = (\n",
" \"\"\n",
")"
"os.environ[\"OPENAI_API_KEY\"] = \"\""
]
},
{
@@ -149,7 +147,7 @@
" \"role\": \"assistant\",\n",
" \"content\": \"Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n",
" },\n",
"]\n"
"]"
]
},
{
@@ -166,9 +164,7 @@
"outputs": [],
"source": [
"# Store inferred memories (default behavior)\n",
"result = m.add(\n",
" messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"}\n",
")"
"result = m.add(messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"})"
]
},
{


@@ -20,19 +20,19 @@ agent = Agent(
name="Fitness Agent",
model=OpenAIChat(id="gpt-4o"),
description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.",
markdown=True
markdown=True,
)
# Store user preferences as memory
def store_user_preferences(conversation: list, user_id: str = USER_ID):
"""Store user preferences from conversation history"""
memory_client.add(conversation, user_id=user_id, output_format='v1.1')
memory_client.add(conversation, user_id=user_id, output_format="v1.1")
# Memory-aware assistant function
def fitness_coach(user_input: str, user_id: str = USER_ID):
memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query
memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query
memory_context = "\n".join(f"- {m['memory']}" for m in memories)
prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations.
@@ -48,113 +48,66 @@ User query:
memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id)
return response.content
# --------------------------------------------------
# Store user preferences and memories
messages = [
{
"role": "user",
"content": "Hi, Im Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle."
"content": "Hi, Im Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle.",
},
{
"role": "assistant",
"content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago."
"content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago.",
},
{
"role": "user",
"content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday."
"content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday.",
},
{
"role": "assistant",
"content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays."
"content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays.",
},
{"role": "user", "content": "After push days, I usually eat high-protein and moderate-carb meals to recover."},
{"role": "assistant", "content": "Noted — high-protein, moderate-carb meals after push workouts."},
{"role": "user", "content": "For pull days, I take whey protein and eat a banana after training."},
{"role": "assistant", "content": "Logged — whey protein and banana post pull workouts."},
{"role": "user", "content": "On leg days, I make sure to have complex carbs like rice or oats."},
{"role": "assistant", "content": "Noted — complex carbs like rice and oats are part of your leg day meals."},
{
"role": "user",
"content": "After push days, I usually eat high-protein and moderate-carb meals to recover."
},
{
"role": "assistant",
"content": "Noted — high-protein, moderate-carb meals after push workouts."
"content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery.",
},
{"role": "assistant", "content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."},
{
"role": "user",
"content": "For pull days, I take whey protein and eat a banana after training."
"content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after.",
},
{
"role": "assistant",
"content": "Logged — whey protein and banana post pull workouts."
"content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward.",
},
{"role": "user", "content": "I prefer light dinners post-workout like tofu, soup, and vegetables."},
{"role": "assistant", "content": "Got it — light dinners post-workout: tofu, soup, and veggies."},
{
"role": "user",
"content": "On leg days, I make sure to have complex carbs like rice or oats."
},
{
"role": "assistant",
"content": "Noted — complex carbs like rice and oats are part of your leg day meals."
"content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey.",
},
{"role": "assistant", "content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."},
{
"role": "user",
"content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery."
"content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days.",
},
{
"role": "assistant",
"content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."
},
{
"role": "user",
"content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after."
},
{
"role": "assistant",
"content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward."
},
{
"role": "user",
"content": "I prefer light dinners post-workout like tofu, soup, and vegetables."
},
{
"role": "assistant",
"content": "Got it — light dinners post-workout: tofu, soup, and veggies."
},
{
"role": "user",
"content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey."
},
{
"role": "assistant",
"content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."
},
{
"role": "user",
"content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days."
},
{
"role": "assistant",
"content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges."
},
{
"role": "user",
"content": "I track sleep and notice poor performance when I sleep less than 6 hours."
},
{
"role": "assistant",
"content": "Logged — performance drops when you get under 6 hours of sleep."
},
{
"role": "user",
"content": "I take magnesium supplements to help with muscle recovery and sleep quality."
},
{
"role": "assistant",
"content": "Remembered — magnesium helps you with recovery and sleep."
},
{
"role": "user",
"content": "I avoid caffeine after 4 PM because it affects my sleep."
},
{
"role": "assistant",
"content": "Got it — you avoid caffeine post-4 PM to protect your sleep."
"content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges.",
},
{"role": "user", "content": "I track sleep and notice poor performance when I sleep less than 6 hours."},
{"role": "assistant", "content": "Logged — performance drops when you get under 6 hours of sleep."},
{"role": "user", "content": "I take magnesium supplements to help with muscle recovery and sleep quality."},
{"role": "assistant", "content": "Remembered — magnesium helps you with recovery and sleep."},
{"role": "user", "content": "I avoid caffeine after 4 PM because it affects my sleep."},
{"role": "assistant", "content": "Got it — you avoid caffeine post-4 PM to protect your sleep."},
]
store_user_preferences(messages)
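A brief usage sketch of the memory-aware coach defined above; the queries are illustrative:

# Once the preferences above are stored, later questions can draw on them
print(fitness_coach("What should I eat after tomorrow's push workout?"))
print(fitness_coach("Plan a leg day that goes easy on my knees."))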

View File

@@ -1,9 +1,11 @@
import asyncio
import warnings
from google.adk.agents import Agent
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types
from mem0 import MemoryClient
warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict:
print(f"Storing patient information: {information[:30]}...")
# Get user_id from session state or use default
user_id = getattr(save_patient_info, 'user_id', 'default_user')
user_id = getattr(save_patient_info, "user_id", "default_user")
# Store in Mem0
response = mem0_client.add(
mem0_client.add(
[{"role": "user", "content": information}],
user_id=user_id,
run_id="healthcare_session",
metadata={"type": "patient_information"}
metadata={"type": "patient_information"},
)
return {"status": "success", "message": "Information saved"}
@@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str:
print(f"Searching for patient information: {query}")
# Get user_id from session state or use default
user_id = getattr(retrieve_patient_info, 'user_id', 'default_user')
user_id = getattr(retrieve_patient_info, "user_id", "default_user")
# Search Mem0
results = mem0_client.search(
@@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str:
user_id=user_id,
run_id="healthcare_session",
limit=5,
threshold=0.7 # Higher threshold for more relevant results
threshold=0.7, # Higher threshold for more relevant results
)
if not results:
@@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict:
"status": "success",
"appointment_id": appointment_id,
"confirmation": f"Appointment scheduled for {date} at {time} for {reason}",
"message": "Please arrive 15 minutes early to complete paperwork."
"message": "Please arrive 15 minutes early to complete paperwork.",
}
@@ -89,7 +91,7 @@ IMPORTANT GUIDELINES:
- For serious symptoms, always recommend consulting a healthcare professional.
- Keep all patient information confidential.
""",
tools=[save_patient_info, retrieve_patient_info, schedule_appointment]
tools=[save_patient_info, retrieve_patient_info, schedule_appointment],
)
# Set Up Session and Runner
@@ -101,18 +103,10 @@ USER_ID = "Alex"
SESSION_ID = "session_001"
# Create a session
session = session_service.create_session(
app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID
)
session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID)
# Create the runner
runner = Runner(
agent=healthcare_agent,
app_name=APP_NAME,
session_service=session_service
)
runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service)
# Interact with the Healthcare Assistant
@@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id):
print(f"\n>>> Patient: {query}")
# Format the user's message
content = types.Content(
role='user',
parts=[types.Part(text=query)]
)
content = types.Content(role="user", parts=[types.Part(text=query)])
# Set user_id for tools to access
save_patient_info.user_id = user_id
retrieve_patient_info.user_id = user_id
# Run the agent
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=content
):
async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content):
if event.is_final_response():
if event.content and event.content.parts:
response = event.content.parts[0].text
@@ -152,7 +139,7 @@ async def run_conversation():
"Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Request for health information
@@ -160,7 +147,7 @@ async def run_conversation():
"Can you tell me more about what might be causing my headaches?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Schedule an appointment
@@ -168,15 +155,12 @@ async def run_conversation():
"I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
session_id=SESSION_ID,
)
# Test memory - should remember patient name, symptoms, and allergy
await call_agent_async(
"What medications should I avoid for my headaches?",
runner=runner,
user_id=USER_ID,
session_id=SESSION_ID
"What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID
)
@@ -191,37 +175,28 @@ async def interactive_mode():
session_id = f"session_{hash(patient_id) % 1000:03d}"
# Create session for this user
session = session_service.create_session(
app_name=APP_NAME,
user_id=patient_id,
session_id=session_id
)
session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id)
print(f"\nStarting conversation with patient ID: {patient_id}")
print("Type your message and press Enter.")
while True:
user_input = input("\n>>> Patient: ").strip()
if user_input.lower() in ['exit', 'quit', 'bye']:
if user_input.lower() in ["exit", "quit", "bye"]:
print("Ending conversation. Thank you!")
break
await call_agent_async(
user_input,
runner=runner,
user_id=patient_id,
session_id=session_id
)
await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id)
# Main execution
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory')
parser.add_argument('--demo', action='store_true', help='Run the demo conversation')
parser.add_argument('--interactive', action='store_true', help='Run in interactive mode')
parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation')
parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory")
parser.add_argument("--demo", action="store_true", help="Run the demo conversation")
parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation")
args = parser.parse_args()
if args.demo:
@@ -231,5 +206,3 @@ if __name__ == "__main__":
else:
# Default to demo mode if no arguments provided
asyncio.run(run_conversation())
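A short usage sketch for the argparse entry point above; the script file name is an assumption:

# Assumed file name, for illustration only
#   python healthcare_assistant.py --demo                              # run the scripted demo
#   python healthcare_assistant.py --interactive --patient-id Alex     # chat loop for one patient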

View File

@@ -16,26 +16,21 @@ from mem0 import Memory
# Configure Mem0 with Grok 3 and Qdrant
config = {
"vector_store": {
"provider": "qdrant",
"config": {
"embedding_model_dims": 384
}
},
"vector_store": {"provider": "qdrant", "config": {"embedding_model_dims": 384}},
"llm": {
"provider": "xai",
"config": {
"model": "grok-3-beta",
"temperature": 0.1,
"max_tokens": 2000,
}
},
},
"embedder": {
"provider": "huggingface",
"config": {
"model": "all-MiniLM-L6-v2" # open embedding model
}
}
},
},
}
# Instantiate memory layer
@@ -57,20 +52,14 @@ def recommend_movie_with_memory(user_id: str, user_query: str):
prompt += f"\nPreviously, the user mentioned: {past_memories}"
# Generate movie recommendation using Grok 3
response = grok_client.chat.completions.create(
model="grok-3-beta",
messages=[
{"role": "user", "content": prompt}
]
)
response = grok_client.chat.completions.create(model="grok-3-beta", messages=[{"role": "user", "content": prompt}])
recommendation = response.choices[0].message.content
# Store conversation in memory
memory.add(
[{"role": "user", "content": user_query},
{"role": "assistant", "content": recommendation}],
[{"role": "user", "content": user_query}, {"role": "assistant", "content": recommendation}],
user_id=user_id,
metadata={"category": "movie"}
metadata={"category": "movie"},
)
return recommendation
@@ -81,10 +70,11 @@ if __name__ == "__main__":
user_id = "arshi"
recommend_movie_with_memory(user_id, "I'm looking for a movie to watch tonight. Any suggestions?")
# OUTPUT: You have watched Interstellar last weekend and you don't like horror movies, maybe you can watch "Purple Hearts" today.
recommend_movie_with_memory(user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians.")
recommend_movie_with_memory(
user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians."
)
# OUTPUT: Got it — no sad endings! You might enjoy "The Proposal" or "Love, Rosie". They're both light-hearted romcoms with happy vibes.
recommend_movie_with_memory(user_id, "Any light-hearted movie I can watch after work today?")
# OUTPUT: Since you liked Crazy Rich Asians and The Proposal, how about "The Intern" or "Isn't It Romantic"? Both are upbeat, funny, and perfect for relaxing.
recommend_movie_with_memory(user_id, "Ive already watched The Intern. Something new maybe?")
# OUTPUT: No problem! Try "Your Place or Mine" - romcoms that match your taste and are tear-free!

View File

@@ -23,8 +23,8 @@ agent = Agent(
name="Personal Agent",
model=OpenAIChat(id="gpt-4o"),
description="You are a helpful personal agent that helps me with day to day activities."
"You can process both text and images.",
markdown=True
"You can process both text and images.",
markdown=True,
)
@@ -35,24 +35,16 @@ def chat_user(user_input: str = None, user_id: str = "user_123", image_path: str
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
# First: the text message
text_msg = {
"role": "user",
"content": user_input
}
text_msg = {"role": "user", "content": user_input}
# Second: the image message
image_msg = {
"role": "user",
"content": {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
"content": {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
}
# Send both as separate message objects
client.add([text_msg, image_msg], user_id=user_id, output_format='v1.1')
client.add([text_msg, image_msg], user_id=user_id, output_format="v1.1")
print("✅ Image uploaded and stored in memory.")
if user_input:
@@ -92,10 +84,13 @@ print(chat_user("When is my test?", user_id=user_id))
# OUTPUT: Your pilot's test is on your birthday, which is in five days. You're turning 25!
# Good luck with your preparations, and remember to take some time to relax amidst the studying.
print(chat_user("This is the picture of what I brought with me in the trip to Bahamas",
image_path="travel_items.jpeg", # this will be added to Mem0 memory
user_id=user_id))
print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find",
user_id=user_id))
print(
chat_user(
"This is the picture of what I brought with me in the trip to Bahamas",
image_path="travel_items.jpeg", # this will be added to Mem0 memory
user_id=user_id,
)
)
print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", user_id=user_id))
# OUTPUT: Yes, you did bring your sunglasses on your trip to the Bahamas along with your laptop, face masks and other items.
# Since you can't find them now, perhaps check the pockets of jackets you wore or in your luggage compartments.

View File

@@ -7,6 +7,7 @@ In order to run this file, you need to set up your Mem0 API at Mem0 platform and
export OPENAI_API_KEY="your_openai_api_key"
export MEM0_API_KEY="your_mem0_api_key"
"""
import asyncio
from agents import Agent, Runner
@@ -23,25 +24,19 @@ study_agent = Agent(
- Identify topics the user has struggled with (e.g., "I'm confused", "this is hard")
- Help with spaced repetition by suggesting topics to revisit based on last review time
- Personalize answers using stored memories
- Summarize PDFs or notes the user uploads""")
- Summarize PDFs or notes the user uploads""",
)
# Upload and store PDF to Mem0
def upload_pdf(pdf_url: str, user_id: str):
pdf_message = {
"role": "user",
"content": {
"type": "pdf_url",
"pdf_url": {"url": pdf_url}
}
}
pdf_message = {"role": "user", "content": {"type": "pdf_url", "pdf_url": {"url": pdf_url}}}
client.add([pdf_message], user_id=user_id)
print("✅ PDF uploaded and processed into memory.")
# Main interaction loop with your personal study buddy
async def study_buddy(user_id: str, topic: str, user_input: str):
memories = client.search(f"{topic}", user_id=user_id)
memory_context = "n".join(f"- {m['memory']}" for m in memories)
@@ -56,9 +51,11 @@ Now respond to the user's new question or comment:
result = await Runner.run(study_agent, prompt)
response = result.final_output
client.add([
{"role": "user", "content": f'''Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}'''}
], user_id=user_id, metadata={"topic": topic})
client.add(
[{"role": "user", "content": f"""Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}"""}],
user_id=user_id,
metadata={"topic": topic},
)
return response
@@ -78,7 +75,12 @@ async def main():
# Demonstrate spaced repetition prompting
topic = "Momentum Conservation"
print(await study_buddy(user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"))
print(
await study_buddy(
user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?"
)
)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -57,7 +57,7 @@ def initialize_memory():
},
{
"role": "user",
"content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know."
"content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know.",
},
{
"role": "assistant",
@@ -65,7 +65,7 @@ def initialize_memory():
},
{
"role": "user",
"content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive."
"content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive.",
},
{
"role": "assistant",
@@ -73,7 +73,7 @@ def initialize_memory():
},
{
"role": "user",
"content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time."
"content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time.",
},
{
"role": "assistant",
@@ -81,7 +81,7 @@ def initialize_memory():
},
{
"role": "user",
"content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants."
"content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants.",
},
{
"role": "assistant",
@@ -89,7 +89,7 @@ def initialize_memory():
},
{
"role": "user",
"content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months."
"content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months.",
},
{
"role": "assistant",
@@ -135,11 +135,11 @@ def record_audio(filename="input.wav", record_seconds=5):
stream.close()
p.terminate()
with wave.open(filename, 'wb') as wf:
with wave.open(filename, "wb") as wf:
wf.setnchannels(channels)
wf.setsampwidth(p.get_sample_size(fmt))
wf.setframerate(rate)
wf.writeframes(b''.join(frames))
wf.writeframes(b"".join(frames))
# ------------------ STT USING WHISPER ------------------
@@ -147,10 +147,7 @@ def transcribe_whisper(audio_path):
print("🔎 Transcribing with Whisper...")
try:
with open(audio_path, "rb") as audio_file:
transcript = openai_client.audio.transcriptions.create(
model="whisper-1",
file=audio_file
)
transcript = openai_client.audio.transcriptions.create(model="whisper-1", file=audio_file)
print(f"🗣️ You said: {transcript.text}")
return transcript.text
except Exception as e:
@@ -165,9 +162,7 @@ def get_agent_response(user_input):
try:
task = Task(
description=f"Respond to: {user_input}",
expected_output="A short and relevant reply.",
agent=voice_agent
description=f"Respond to: {user_input}", expected_output="A short and relevant reply.", agent=voice_agent
)
crew = Crew(
agents=[voice_agent],
@@ -175,22 +170,19 @@ def get_agent_response(user_input):
process=Process.sequential,
verbose=True,
memory=True,
memory_config={
"provider": "mem0",
"config": {"user_id": USER_ID}
}
memory_config={"provider": "mem0", "config": {"user_id": USER_ID}},
)
result = crew.kickoff()
# Extract the text response from the complex result object
if hasattr(result, 'raw'):
if hasattr(result, "raw"):
return result.raw
elif isinstance(result, dict) and 'raw' in result:
return result['raw']
elif isinstance(result, dict) and 'tasks_output' in result:
outputs = result['tasks_output']
elif isinstance(result, dict) and "raw" in result:
return result["raw"]
elif isinstance(result, dict) and "tasks_output" in result:
outputs = result["tasks_output"]
if outputs and isinstance(outputs, list) and len(outputs) > 0:
return outputs[0].get('raw', str(result))
return outputs[0].get("raw", str(result))
# Fallback to string representation if we can't extract the raw response
return str(result)
@@ -204,10 +196,7 @@ def get_agent_response(user_input):
def speak_response(text):
print(f"🤖 Agent: {text}")
audio = tts_client.text_to_speech.convert(
text=text,
voice_id="JBFqnCBsd6RMkjVDRZzb",
model_id="eleven_multilingual_v2",
output_format="mp3_44100_128"
text=text, voice_id="JBFqnCBsd6RMkjVDRZzb", model_id="eleven_multilingual_v2", output_format="mp3_44100_128"
)
play(audio)
@@ -220,7 +209,7 @@ def run_voice_agent():
record_audio(tmp_audio.name)
try:
user_text = transcribe_whisper(tmp_audio.name)
if user_text.lower() in ['exit', 'quit', 'stop']:
if user_text.lower() in ["exit", "quit", "stop"]:
print("👋 Exiting.")
break
response = get_agent_response(user_text)
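# Usage sketch (assumed entry point): the record -> transcribe -> respond loop above is
# normally started from a main guard, e.g.
#   if __name__ == "__main__":
#       run_voice_agent()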

View File

@@ -95,10 +95,7 @@ class MemoryClient:
self.client = client
# Ensure the client has the correct base_url and headers
self.client.base_url = httpx.URL(self.host)
self.client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
self.client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else:
self.client = httpx.Client(
base_url=self.host,
@@ -237,7 +234,9 @@ class MemoryClient:
response.raise_for_status()
if "metadata" in kwargs:
del kwargs["metadata"]
capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"})
capture_client_event(
"client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json()
@api_error_handler
@@ -357,10 +356,7 @@ class MemoryClient:
else:
entities = self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params()
@@ -373,7 +369,9 @@ class MemoryClient:
response.raise_for_status()
capture_client_event(
"client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"}
"client.delete_users",
self,
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"},
)
return {
"message": "Entity deleted successfully."
@@ -454,7 +452,9 @@ class MemoryClient:
"""
response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
response.raise_for_status()
capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"})
capture_client_event(
"client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}
)
return response.json()
@api_error_handler
@@ -527,7 +527,11 @@ class MemoryClient:
)
payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
}
)
response = self.client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -537,7 +541,12 @@ class MemoryClient:
capture_client_event(
"client.update_project",
self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "sync"},
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"sync_type": "sync",
},
)
return response.json()
@@ -750,10 +759,7 @@ class AsyncMemoryClient:
self.async_client = client
# Ensure the client has the correct base_url and headers
self.async_client.base_url = httpx.URL(self.host)
self.async_client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
self.async_client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
else:
self.async_client = httpx.AsyncClient(
base_url=self.host,
@@ -768,7 +774,11 @@ class AsyncMemoryClient:
"""Validate the API key by making a test request."""
try:
params = self._prepare_params()
response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params)
response = requests.get(
f"{self.host}/v1/ping/",
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
params=params,
)
data = response.json()
response.raise_for_status()
@@ -973,10 +983,7 @@ class AsyncMemoryClient:
else:
entities = await self.users()
# Filter entities based on provided IDs using list comprehension
to_delete = [
{"type": entity["type"], "name": entity["name"]}
for entity in entities["results"]
]
to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
params = self._prepare_params()
@@ -988,7 +995,11 @@ class AsyncMemoryClient:
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status()
capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"})
capture_client_event(
"client.delete_users",
self,
{"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"},
)
return {
"message": "Entity deleted successfully."
if (user_id or agent_id or app_id or run_id)
@@ -1091,8 +1102,10 @@ class AsyncMemoryClient:
@api_error_handler
async def update_project(
self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None,
retrieval_criteria: Optional[List[Dict[str, Any]]] = None
self,
custom_instructions: Optional[str] = None,
custom_categories: Optional[List[str]] = None,
retrieval_criteria: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1103,7 +1116,11 @@ class AsyncMemoryClient:
)
payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
}
)
response = await self.async_client.patch(
f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
@@ -1113,7 +1130,12 @@ class AsyncMemoryClient:
capture_client_event(
"client.update_project",
self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"},
{
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"sync_type": "async",
},
)
return response.json()
@@ -1174,4 +1196,3 @@ class AsyncMemoryClient:
response.raise_for_status()
capture_client_event("client.feedback", self, data, {"sync_type": "async"})
return response.json()
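A minimal async usage sketch for the client class above, assuming MEM0_API_KEY is set in the environment and that the async methods mirror the synchronous signatures:

import asyncio
from mem0 import AsyncMemoryClient  # assumed import path

async def main():
    client = AsyncMemoryClient()
    # Scoped search, mirroring the sync client's parameters
    results = await client.search("dietary preferences", user_id="alice")
    print(results)

asyncio.run(main())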

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union, Type
from typing import Any, Dict, Optional, Type, Union
from pydantic import BaseModel, Field, model_validator
@@ -7,33 +7,17 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port")
user: Optional[str] = Field(
None, description="Username for authentication"
)
password: Optional[str] = Field(
None, description="Password for authentication"
)
api_key: Optional[str] = Field(
None, description="API key for authentication (if applicable)"
)
embedding_model_dims: int = Field(
1536, description="Dimension of the embedding vector"
)
verify_certs: bool = Field(
False, description="Verify SSL certificates (default False for OpenSearch)"
)
use_ssl: bool = Field(
False, description="Use SSL for connection (default False for OpenSearch)"
)
http_auth: Optional[object] = Field(
None, description="HTTP authentication method / AWS SigV4"
)
user: Optional[str] = Field(None, description="Username for authentication")
password: Optional[str] = Field(None, description="Password for authentication")
api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
connection_class: Optional[Union[str, Type]] = Field(
"RequestsHttpConnection", description="Connection class for OpenSearch"
)
pool_maxsize: int = Field(
20, description="Maximum number of connections in the pool"
)
pool_maxsize: int = Field(20, description="Maximum number of connections in the pool")
@model_validator(mode="before")
@classmethod
@@ -41,7 +25,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided
if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch")
return values
@model_validator(mode="before")
@@ -52,7 +36,6 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
)
return values
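An illustrative instantiation of the config model above; the import path is assumed and all values are placeholders:

from mem0.configs.vector_stores.opensearch import OpenSearchConfig  # assumed module path

cfg = OpenSearchConfig(
    collection_name="mem0",
    host="localhost",
    port=9200,
    embedding_model_dims=1536,
)
# Any unknown keyword argument is rejected by the extra-fields validator above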

View File

@@ -23,12 +23,12 @@ class AWSBedrockEmbedding(EmbeddingBase):
super().__init__(config)
self.config.model = self.config.model or "amazon.titan-embed-text-v1"
# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
@@ -36,7 +36,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region
self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,

View File

@@ -11,6 +11,7 @@ logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
class HuggingFaceEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

View File

@@ -22,7 +22,8 @@ class Neo4jConfig(BaseModel):
if not url or not username or not password:
raise ValueError("Please provide 'url', 'username' and 'password'.")
return values
class MemgraphConfig(BaseModel):
url: Optional[str] = Field(None, description="Host address for the graph database")
username: Optional[str] = Field(None, description="Username for the graph database")

View File

@@ -20,18 +20,19 @@ def extract_provider(model: str) -> str:
return provider
raise ValueError(f"Unknown provider in model: {model}")
class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
if not self.config.model:
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
@@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region
self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
)
self.model_kwargs = {
"temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens,
@@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase):
input_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"]
or self.model_kwargs["max_tokens"]
or 5000,
"topP": self.model_kwargs["top_p"] or 0.9,
"temperature": self.model_kwargs["temperature"] or 0.1,
},
@@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase):
body = json.dumps(input_body)
if provider == "anthropic" or provider == "deepseek":
input_body = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
}
],
"messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"temperature": self.model_kwargs["temperature"] or 0.1,
"top_p": self.model_kwargs["top_p"] or 0.9,
"anthropic_version": "bedrock-2023-05-31",
}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
@@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase):
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)
)
return self._parse_response(response, tools)

View File

@@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
def _build_filters_and_metadata(
*, # Enforce keyword-only arguments
*, # Enforce keyword-only arguments
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
actor_id: Optional[str] = None, # For query-time filtering
actor_id: Optional[str] = None, # For query-time filtering
input_metadata: Optional[Dict[str, Any]] = None,
input_filters: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, Any], Dict[str, Any]]:
"""
Constructs metadata for storage and filters for querying based on session and actor identifiers.
This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
@@ -78,10 +78,10 @@ def _build_filters_and_metadata(
- effective_query_filters (Dict[str, Any]): Filters for querying memories,
scoped to the determined session and potentially a resolved actor.
"""
base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
effective_query_filters = deepcopy(input_filters) if input_filters else {}
# ---------- resolve session id (mandatory) ----------
session_key, session_val = None, None
if user_id:
@@ -90,20 +90,20 @@ def _build_filters_and_metadata(
session_key, session_val = "agent_id", agent_id
elif run_id:
session_key, session_val = "run_id", run_id
if session_key is None:
raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
base_metadata_template[session_key] = session_val
effective_query_filters[session_key] = session_val
# ---------- optional actor filter ----------
resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
if resolved_actor_id:
effective_query_filters["actor_id"] = resolved_actor_id
return base_metadata_template, effective_query_filters
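# Illustrative call (values assumed), showing the two dicts the helper returns:
#   metadata, query_filters = _build_filters_and_metadata(user_id="alice", actor_id="assistant")
#   metadata      -> {"user_id": "alice"}
#   query_filters -> {"user_id": "alice", "actor_id": "assistant"}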
setup_config()
logger = logging.getLogger(__name__)
@@ -189,7 +189,7 @@ class Memory(MemoryBase):
):
"""
Create a new memory.
Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.
Args:
@@ -208,7 +208,7 @@ class Memory(MemoryBase):
creating procedural memories (typically requires 'agent_id'). Otherwise, memories
are treated as general conversational/factual memories.
memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates short-term and long-term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
Returns:
dict: A dictionary containing the result of the memory addition operation, typically
@@ -216,14 +216,14 @@ class Memory(MemoryBase):
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
"""
processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata,
)
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -231,10 +231,10 @@ class Memory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict):
messages = [messages]
elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")
@@ -255,7 +255,7 @@ class Memory(MemoryBase):
vector_store_result = future1.result()
graph_result = future2.result()
if self.api_version == "v1.0":
warnings.warn(
"The current add API output format is deprecated. "
@@ -277,21 +277,21 @@ class Memory(MemoryBase):
def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
for message_dict in messages:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format: {message_dict}")
continue
if message_dict["role"] == "system":
continue
continue
per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name
@@ -311,8 +311,8 @@ class Memory(MemoryBase):
)
return returned_memories
parsed_messages = parse_messages(messages)
parsed_messages = parse_messages(messages)
if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}"
@@ -336,7 +336,7 @@ class Memory(MemoryBase):
retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
@@ -347,7 +347,7 @@ class Memory(MemoryBase):
)
for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
unique_data = {}
for item in retrieved_old_memory:
unique_data[item["id"]] = item
@@ -389,7 +389,7 @@ class Memory(MemoryBase):
if not action_text:
logging.info("Skipping memory entry because of empty `text` field.")
continue
event_type = resp.get("event")
if event_type == "ADD":
memory_id = self._create_memory(
@@ -405,16 +405,23 @@ class Memory(MemoryBase):
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type, "previous_memory": resp.get("old_memory"),
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type,
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
}
)
elif event_type == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -462,11 +469,8 @@ class Memory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
result_item = MemoryItem(
id=memory.id,
@@ -479,18 +483,16 @@ class Memory(MemoryBase):
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata
return result_item
def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -505,7 +507,7 @@ class Memory(MemoryBase):
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
@@ -515,21 +517,16 @@ class Memory(MemoryBase):
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
)
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -542,9 +539,9 @@ class Memory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}
@@ -556,26 +553,27 @@ class Memory(MemoryBase):
category=DeprecationWarning,
stacklevel=2,
)
return all_memories_result
return all_memories_result
else:
return {"results": all_memories_result}
def _get_all_from_vector_store(self, filters, limit):
memories_result = self.vector_store.list(filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)
promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -587,15 +585,13 @@ class Memory(MemoryBase):
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict)
return formatted_memories
def search(
@@ -624,12 +620,9 @@ class Memory(MemoryBase):
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
@@ -651,7 +644,7 @@ class Memory(MemoryBase):
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
@@ -678,11 +671,8 @@ class Memory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
original_memories = []
for mem in memories:
@@ -693,18 +683,16 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
).model_dump()
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict)
return original_memories
@@ -738,7 +726,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None):
def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
"""
Delete all memories.
@@ -860,11 +848,11 @@ class Memory(MemoryBase):
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -875,7 +863,7 @@ class Memory(MemoryBase):
if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -885,14 +873,14 @@ class Memory(MemoryBase):
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update(
vector_id=memory_id,
vector=embeddings,
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(
memory_id,
prev_value,
@@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase):
dict: A dictionary containing the result of the memory addition operation.
"""
processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata
user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
)
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -1050,15 +1035,17 @@ class AsyncMemory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict):
messages = [messages]
elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm)
results = await self._create_procedural_memory(
messages, metadata=processed_metadata, prompt=prompt, llm=llm
)
return results
if self.config.llm.config.get("enable_vision"):
@@ -1066,7 +1053,9 @@ class AsyncMemory(MemoryBase):
else:
messages = parse_vision_messages(messages)
vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer))
vector_store_task = asyncio.create_task(
self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)
)
graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters))
vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
@@ -1090,8 +1079,8 @@ class AsyncMemory(MemoryBase):
return {"results": vector_store_result}
async def _add_to_vector_store(
self,
messages: list,
self,
messages: list,
metadata: dict,
filters: dict,
infer: bool,
@@ -1099,9 +1088,11 @@ class AsyncMemory(MemoryBase):
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format (async): {message_dict}")
continue
@@ -1110,20 +1101,24 @@ class AsyncMemory(MemoryBase):
per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name
msg_content = message_dict["content"]
msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta)
returned_memories.append({
"id": mem_id, "memory": msg_content, "event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
})
returned_memories.append(
{
"id": mem_id,
"memory": msg_content,
"event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
}
)
return returned_memories
parsed_messages = parse_messages(messages)
@@ -1142,17 +1137,21 @@ class AsyncMemory(MemoryBase):
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = []
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
retrieved_old_memory = []
new_message_embeddings = {}
async def process_fact_for_search(new_mem_content):
embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add")
new_message_embeddings[new_mem_content] = embeddings
existing_mems = await asyncio.to_thread(
self.vector_store.search, query=new_mem_content, vectors=embeddings,
limit=5, filters=filters, # 'filters' is query_filters_for_inference
self.vector_store.search,
query=new_mem_content,
vectors=embeddings,
limit=5,
filters=filters, # 'filters' is query_filters_for_inference
)
return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
@@ -1160,9 +1159,10 @@ class AsyncMemory(MemoryBase):
search_results_list = await asyncio.gather(*search_tasks)
for result_group in search_results_list:
retrieved_old_memory.extend(result_group)
unique_data = {}
for item in retrieved_old_memory: unique_data[item["id"]] = item
for item in retrieved_old_memory:
unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
temp_uuid_mapping = {}
@@ -1180,35 +1180,45 @@ class AsyncMemory(MemoryBase):
response_format={"type": "json_object"},
)
except Exception as e:
logging.error(f"Error in new memory actions response: {e}"); response = ""
logging.error(f"Error in new memory actions response: {e}")
response = ""
try:
response = remove_code_blocks(response)
new_memories_with_actions = json.loads(response)
except Exception as e:
logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {}
logging.error(f"Invalid JSON response: {e}")
new_memories_with_actions = {}
returned_memories = []
returned_memories = []
try:
memory_tasks = []
for resp in new_memories_with_actions.get("memory", []):
logging.info(resp)
try:
action_text = resp.get("text")
if not action_text: continue
if not action_text:
continue
event_type = resp.get("event")
if event_type == "ADD":
task = asyncio.create_task(self._create_memory(
data=action_text, existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._create_memory(
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "ADD", None))
elif event_type == "UPDATE":
task = asyncio.create_task(self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]], data=action_text,
existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]],
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif event_type == "DELETE":
task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
@@ -1217,31 +1227,30 @@ class AsyncMemory(MemoryBase):
logging.info("NOOP for Memory (async).")
except Exception as e:
logging.error(f"Error processing memory action (async): {resp}, Error: {e}")
for task, resp, event_type, mem_id in memory_tasks:
try:
result_id = await task
if event_type == "ADD":
returned_memories.append({
"id": result_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type})
elif event_type == "UPDATE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"),
"event": event_type, "previous_memory": resp.get("old_memory")
})
returned_memories.append(
{
"id": mem_id,
"memory": resp.get("text"),
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type})
except Exception as e:
logging.error(f"Error awaiting memory task (async): {e}")
except Exception as e:
logging.error(f"Error in memory processing loop (async): {e}")
capture_event(
"mem0.add", self,
{"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
"mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
)
return returned_memories
@@ -1272,17 +1281,14 @@ class AsyncMemory(MemoryBase):
return None
promoted_payload_keys = [
"user_id",
"agent_id",
"run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
result_item = MemoryItem(
id=memory.id,
@@ -1295,18 +1301,16 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata
return result_item
async def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -1314,41 +1318,36 @@ class AsyncMemory(MemoryBase):
limit: int = 100,
):
"""
List all memories.
List all memories.
Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)
capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
)
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1361,9 +1360,9 @@ class AsyncMemory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}
@@ -1381,20 +1380,21 @@ class AsyncMemory(MemoryBase):
async def _get_all_from_vector_store(self, filters, limit):
memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)
promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -1406,15 +1406,13 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict)
return formatted_memories
async def search(
@@ -1442,16 +1440,13 @@ class AsyncMemory(MemoryBase):
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
capture_event(
"mem0.search",
@@ -1460,22 +1455,20 @@ class AsyncMemory(MemoryBase):
)
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
graph_task = None
if self.enable_graph:
if hasattr(self.graph.search, "__await__"): # Check if graph search is async
graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit))
else:
graph_task = asyncio.create_task(
asyncio.to_thread(self.graph.search, query, effective_filters, limit)
)
graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit))
if graph_task:
original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else:
original_memories = await vector_store_task
graph_entities = None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
@@ -1504,11 +1497,8 @@ class AsyncMemory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
original_memories = []
for mem in memories:
@@ -1518,19 +1508,17 @@ class AsyncMemory(MemoryBase):
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
score=mem.score,
).model_dump()
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict)
return original_memories
@@ -1650,7 +1638,7 @@ class AsyncMemory(MemoryBase):
capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"})
return memory_id
async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None):
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
"""
Create a procedural memory asynchronously
@@ -1709,11 +1697,11 @@ class AsyncMemory(MemoryBase):
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -1725,8 +1713,7 @@ class AsyncMemory(MemoryBase):
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -1736,7 +1723,7 @@ class AsyncMemory(MemoryBase):
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
await asyncio.to_thread(
self.vector_store.update,
vector_id=memory_id,
@@ -1744,7 +1731,7 @@ class AsyncMemory(MemoryBase):
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
await asyncio.to_thread(
self.db.add_history,
memory_id,
View File
@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try:
from langchain_memgraph import Memgraph
except ImportError:
raise ImportError(
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
)
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
)
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(
data, filters, entity_type_map
)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
to_be_deleted = self._get_delete_entities_from_search_output(
search_output, data, filters
)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(
to_be_added, filters["user_id"], entity_type_map
)
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
@@ -108,16 +96,13 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]]
for item in search_output
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
@@ -126,9 +111,7 @@ class MemoryGraph:
search_results = []
for item in reranked_results:
search_results.append(
{"source": item[0], "relationship": item[1], "destination": item[2]}
)
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
@@ -161,9 +144,7 @@ class MemoryGraph:
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(
query, params={"user_id": filters["user_id"], "limit": limit}
)
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = []
for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {
k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
for k, v in entity_type_map.items()
}
logger.debug(
f"Entity type map: {entity_type_map}\n search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
).replace(
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
@@ -235,9 +209,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
),
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
},
{
"role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(
search_output_string, data, filters["user_id"]
)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
# search for the nodes with the closest embeddings; this is basically
# comparison of one embedding to all embeddings in a graph -> vector
# search with cosine similarity metric
source_node_search_result = self._search_source_node(
source_embedding, user_id, threshold=0.9
)
destination_node_search_result = self._search_destination_node(
dest_embedding, user_id, threshold=0.9
)
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
"""
params = {
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"user_id": user_id,
}
else:
View File
@@ -1,8 +1,8 @@
import logging
import sqlite3
import threading
import uuid
import logging
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -23,9 +23,7 @@ class SQLiteManager:
"""
with self._lock, self.connection:
cur = self.connection.cursor()
cur.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
)
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
if cur.fetchone() is None:
return # nothing to migrate
@@ -51,13 +49,11 @@ class SQLiteManager:
logger.info("Migrating history table to new schema (no convo columns).")
cur.execute("ALTER TABLE history RENAME TO history_old")
self._create_history_table()
self._create_history_table()
intersecting = list(expected_cols & old_cols)
cols_csv = ", ".join(intersecting)
cur.execute(
f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old"
)
cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
cur.execute("DROP TABLE history_old")
def _create_history_table(self) -> None:
View File
@@ -9,8 +9,8 @@ import mem0
from mem0.memory.setup import get_or_create_user_id
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST="https://us.i.posthog.com"
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"
if isinstance(MEM0_TELEMETRY, str):
MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
View File
@@ -98,9 +98,8 @@ class VectorStoreFactory:
return vector_store_instance(**config)
else:
raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
@classmethod
def reset(cls, instance):
instance.reset()
return instance
View File
@@ -377,4 +377,3 @@ class AzureAISearch(VectorStoreBase):
except Exception as e:
logger.error(f"Error resetting index {self.index_name}: {e}")
raise
View File
@@ -51,7 +51,7 @@ class VectorStoreBase(ABC):
def list(self, filters=None, limit=None):
"""List all memories."""
pass
@abstractmethod
def reset(self):
"""Reset by delete the collection and recreate it."""
View File
@@ -221,7 +221,7 @@ class ChromaDB(VectorStoreBase):
"""
results = self.collection.get(where=filters, limit=limit)
return [self._parse_output(results)]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -58,7 +58,12 @@ class ElasticsearchDB(VectorStoreBase):
"mappings": {
"properties": {
"text": {"type": "text"},
"vector": {"type": "dense_vector", "dims": self.embedding_model_dims, "index": True, "similarity": "cosine"},
"vector": {
"type": "dense_vector",
"dims": self.embedding_model_dims,
"index": True,
"similarity": "cosine",
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
},
@@ -222,7 +227,7 @@ class ElasticsearchDB(VectorStoreBase):
)
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -465,7 +465,7 @@ class FAISS(VectorStoreBase):
break
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
score: Optional[float] # distance
@@ -162,10 +163,7 @@ class Langchain(VectorStoreBase):
if filters and "user_id" in filters:
where_clause = {"user_id": filters["user_id"]}
result = self.client._collection.get(
where=where_clause,
limit=limit
)
result = self.client._collection.get(where=where_clause, limit=limit)
# Convert the result to the expected format
if result and isinstance(result, dict):
View File
@@ -237,7 +237,7 @@ class MilvusDB(VectorStoreBase):
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
return [memories]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -1,6 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
import time
from typing import Any, Dict, List, Optional
try:
from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=20
pool_maxsize=20,
)
self.collection_name = config.collection_name
@@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
"settings": {
"index.knn": True
},
"settings": {"index.knn": True},
"mappings": {
"properties": {
"vector_field": {
@@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase):
"payload": {"type": "object"},
"id": {"type": "keyword"},
}
}
},
}
if not self.client.indices.exists(index=name):
@@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
retry_count += 1
if retry_count == max_retries:
raise TimeoutError(
f"Index {name} creation timed out after {max_retries} seconds"
)
raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
time.sleep(0.5)
def insert(
@@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase):
}
# Start building the full query
query_body = {
"size": limit * 2,
"query": None
}
query_body = {"size": limit * 2, "query": None}
# Prepare filter conditions if applicable
filter_clauses = []
@@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase):
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
# Combine knn with filters if needed
if filter_clauses:
query_body["query"] = {
"bool": {
"must": knn_query,
"filter": filter_clauses
}
}
query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
else:
query_body["query"] = knn_query
@@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase):
hits = response["hits"]["hits"]
results = [
OutputData(
id=hit["_source"].get("id"),
score=hit["_score"],
payload=hit["_source"].get("payload", {})
)
OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
for hit in hits
]
return results
@@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase):
def delete(self, vector_id: str) -> None:
"""Delete a vector by custom ID."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
@@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase):
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
@@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
pass
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
@@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase):
self.create_col(self.collection_name, self.embedding_model_dims)
return None
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response["hits"]["hits"]
@@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase):
if not hits:
return None
return OutputData(
id=hits[0]["_source"].get("id"),
score=1.0,
payload=hits[0]["_source"].get("payload", {})
)
return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
@@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase):
return self.client.indices.get(index=name)
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
try:
"""List all memories with optional filters."""
query: Dict = {
"query": {
"match_all": {}
}
}
query: Dict = {"query": {"match_all": {}}}
filter_clauses = []
if filters:
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
if filter_clauses:
query["query"] = {
"bool": {
"filter": filter_clauses
}
}
query["query"] = {"bool": {"filter": filter_clauses}}
if limit:
query["size"] = limit
@@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase):
response = self.client.search(index=self.collection_name, body=query)
hits = response["hits"]["hits"]
return [[
OutputData(
id=hit["_source"].get("id"),
score=1.0,
payload=hit["_source"].get("payload", {})
)
for hit in hits
]]
return [
[
OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
for hit in hits
]
]
except Exception:
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")

View File

@@ -286,7 +286,7 @@ class PGVector(VectorStoreBase):
self.cur.close()
if hasattr(self, "conn"):
self.conn.close()
def reset(self):
"""Reset the index by deleting and recreating it."""
View File

@@ -232,7 +232,7 @@ class Qdrant(VectorStoreBase):
with_vectors=False,
)
return result
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -88,7 +88,7 @@ class RedisDB(VectorStoreBase):
The created index object.
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.schema['index']['name']
collection_name = name or self.schema["index"]["name"]
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "cosine"
@@ -237,17 +237,16 @@ class RedisDB(VectorStoreBase):
"""
Reset the index by deleting and recreating it.
"""
collection_name = self.schema['index']['name']
collection_name = self.schema["index"]["name"]
logger.warning(f"Resetting index {collection_name}...")
self.delete_col()
self.index = SearchIndex.from_dict(self.schema)
self.index.set_client(self.client)
self.index.create(overwrite=True)
#or use
#self.create_col(collection_name, self.embedding_model_dims)
# or use
# self.create_col(collection_name, self.embedding_model_dims)
# Recreate the index with the same parameters
self.create_col(collection_name, self.embedding_model_dims)
View File
@@ -229,7 +229,7 @@ class Supabase(VectorStoreBase):
records = self.collection.fetch(ids=ids)
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
View File
@@ -285,10 +285,9 @@ class UpstashVector(VectorStoreBase):
- Per-namespace vector and pending vector counts
"""
return self.client.info()
def reset(self):
"""
Reset the Upstash Vector index.
"""
self.delete_col()
View File
@@ -308,7 +308,7 @@ class Weaviate(VectorStoreBase):
payload["id"] = str(obj.uuid).split("'")[0]
results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
View File

@@ -44,31 +44,14 @@ DEFAULT_CONFIG = {
"user": POSTGRES_USER,
"password": POSTGRES_PASSWORD,
"collection_name": POSTGRES_COLLECTION_NAME,
}
},
},
"graph_store": {
"provider": "neo4j",
"config": {
"url": NEO4J_URI,
"username": NEO4J_USERNAME,
"password": NEO4J_PASSWORD
}
},
"llm": {
"provider": "openai",
"config": {
"api_key": OPENAI_API_KEY,
"temperature": 0.2,
"model": "gpt-4o"
}
},
"embedder": {
"provider": "openai",
"config": {
"api_key": OPENAI_API_KEY,
"model": "text-embedding-3-small"
}
"config": {"url": NEO4J_URI, "username": NEO4J_USERNAME, "password": NEO4J_PASSWORD},
},
"llm": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "temperature": 0.2, "model": "gpt-4o"}},
"embedder": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "model": "text-embedding-3-small"}},
"history_db_path": HISTORY_DB_PATH,
}
@@ -115,9 +98,7 @@ def set_config(config: Dict[str, Any]):
def add_memory(memory_create: MemoryCreate):
"""Store new memories."""
if not any([memory_create.user_id, memory_create.agent_id, memory_create.run_id]):
raise HTTPException(
status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required."
)
raise HTTPException(status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required.")
params = {k: v for k, v in memory_create.model_dump().items() if v is not None and k != "messages"}
try:
@@ -138,7 +119,9 @@ def get_all_memories(
if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.")
try:
params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None}
params = {
k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
}
return MEMORY_INSTANCE.get_all(**params)
except Exception as e:
logging.exception("Error in get_all_memories:")
@@ -207,7 +190,9 @@ def delete_all_memories(
if not any([user_id, run_id, agent_id]):
raise HTTPException(status_code=400, detail="At least one identifier is required.")
try:
params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None}
params = {
k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None
}
MEMORY_INSTANCE.delete_all(**params)
return {"message": "All relevant memories deleted"}
except Exception as e:
@@ -229,4 +214,4 @@ def reset_memory():
@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False)
def home():
"""Redirect to the OpenAPI documentation."""
return RedirectResponse(url='/docs')
return RedirectResponse(url="/docs")
View File
@@ -5,13 +5,15 @@ def test_get_update_memory_messages():
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
response_content = ["new fact"]
custom_update_memory_prompt = "custom prompt determining memory update"
## When custom update memory prompt is provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt)
result = prompts.get_update_memory_messages(
retrieved_old_memory_dict, response_content, custom_update_memory_prompt
)
assert result.startswith(custom_update_memory_prompt)
## When custom update memory prompt is not provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)
View File
@@ -10,9 +10,7 @@ from mem0.embeddings.lmstudio import LMStudioEmbedding
def mock_lm_studio_client():
with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai:
mock_client = Mock()
mock_client.embeddings.create.return_value = Mock(
data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])]
)
mock_client.embeddings.create.return_value = Mock(data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])])
mock_openai.return_value = mock_client
yield mock_client
View File
@@ -23,7 +23,9 @@ def test_embed_default_model(mock_openai_client):
result = embedder.embed("Hello world")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.1, 0.2, 0.3]
@@ -37,7 +39,7 @@ def test_embed_custom_model(mock_openai_client):
result = embedder.embed("Test embedding")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024
input=["Test embedding"], model="text-embedding-2-medium", dimensions=1024
)
assert result == [0.4, 0.5, 0.6]
@@ -51,7 +53,9 @@ def test_embed_removes_newlines(mock_openai_client):
result = embedder.embed("Hello\nworld")
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Hello world"], model="text-embedding-3-small", dimensions=1536
)
assert result == [0.7, 0.8, 0.9]
@@ -65,7 +69,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
result = embedder.embed("Testing API key")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536
input=["Testing API key"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.0, 1.1, 1.2]
@@ -81,6 +85,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
result = embedder.embed("Environment key test")
mock_openai_client.embeddings.create.assert_called_once_with(
input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536
input=["Environment key test"], model="text-embedding-3-small", dimensions=1536
)
assert result == [1.3, 1.4, 1.5]
View File
@@ -24,11 +24,20 @@ def mock_config():
with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config:
mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json"
yield mock_config
@pytest.fixture
def mock_embedding_types():
return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"]
return [
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"RETRIEVAL_DOCUMENT",
"RETRIEVAL_QUERY",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
"CODE_RETRIEVAL_QUERY",
]
@pytest.fixture
@@ -79,30 +88,31 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
assert result == [0.4, 0.5, 0.6]
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input):
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_embed_with_memory_action(
mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input
):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
for embedding_type in mock_embedding_types:
mock_config.return_value.memory_add_embedding_type = embedding_type
mock_config.return_value.memory_update_embedding_type = embedding_type
mock_config.return_value.memory_search_embedding_type = embedding_type
config = mock_config()
embedder = VertexAIEmbedding(config)
mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004")
for memory_action in ["add", "update", "search"]:
embedder.embed("Hello world", memory_action=memory_action)
mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with(
texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256
)
@patch("mem0.embeddings.vertexai.os")
def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config):
@@ -137,15 +147,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi
result = embedder.embed("Large embedding test")
assert result == [0.1] * 1024
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
def test_invalid_memory_action(mock_text_embedding_model, mock_config):
mock_config.return_value.model = "text-embedding-004"
mock_config.return_value.embedding_dims = 256
config = mock_config()
embedder = VertexAIEmbedding(config)
with pytest.raises(ValueError):
embedder.embed("Hello world", memory_action="invalid_action")
embedder.embed("Hello world", memory_action="invalid_action")
View File
@@ -127,4 +127,4 @@ def test_generate_with_http_proxies(default_headers):
api_version=None,
default_headers=default_headers,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
View File
@@ -31,12 +31,12 @@ def test_deepseek_llm_base_url():
# case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url,
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url
@@ -99,16 +99,16 @@ def test_generate_response_with_tools(mock_deepseek_client):
response = llm.generate_response(messages, tools=tools)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto"
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -10,6 +10,7 @@ try:
from langchain.chat_models.base import BaseChatModel
except ImportError:
from unittest.mock import MagicMock
BaseChatModel = MagicMock
@@ -24,16 +25,11 @@ def mock_langchain_model():
def test_langchain_initialization(mock_langchain_model):
"""Test that LangchainLLM initializes correctly with a valid model."""
# Create a config with the model instance directly
config = BaseLlmConfig(
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
# Initialize the LangchainLLM
llm = LangchainLLM(config)
# Verify the model was correctly assigned
assert llm.langchain_model == mock_langchain_model
@@ -41,35 +37,30 @@ def test_langchain_initialization(mock_langchain_model):
def test_generate_response(mock_langchain_model):
"""Test that generate_response correctly processes messages and returns a response."""
# Create a config with the model instance
config = BaseLlmConfig(
model=mock_langchain_model,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
# Initialize the LangchainLLM
llm = LangchainLLM(config)
# Create test messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{"role": "user", "content": "Tell me a joke."}
{"role": "user", "content": "Tell me a joke."},
]
# Get response
response = llm.generate_response(messages)
# Verify the correct message format was passed to the model
expected_langchain_messages = [
("system", "You are a helpful assistant."),
("human", "Hello, how are you?"),
("ai", "I'm doing well! How can I help you?"),
("human", "Tell me a joke.")
("human", "Tell me a joke."),
]
mock_langchain_model.invoke.assert_called_once()
# Extract the first argument of the first call
actual_messages = mock_langchain_model.invoke.call_args[0][0]
@@ -79,25 +70,15 @@ def test_generate_response(mock_langchain_model):
def test_invalid_model():
"""Test that LangchainLLM raises an error with an invalid model."""
config = BaseLlmConfig(
model="not-a-valid-model-instance",
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key")
with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
LangchainLLM(config)
def test_missing_model():
"""Test that LangchainLLM raises an error when model is None."""
config = BaseLlmConfig(
model=None,
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
config = BaseLlmConfig(model=None, temperature=0.7, max_tokens=100, api_key="test-api-key")
with pytest.raises(ValueError, match="`model` parameter is required"):
LangchainLLM(config)
View File
@@ -11,9 +11,7 @@ def mock_lm_studio_client():
with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path
mock_client = Mock()
mock_client.chat.completions.create.return_value = Mock(
choices=[
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
choices=[Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
)
mock_openai.return_value = mock_client
yield mock_client
View File
@@ -10,18 +10,19 @@ def _setup_mocks(mocker):
"""Helper to setup common mocks for both sync and async fixtures"""
mock_embedder = mocker.MagicMock()
mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder)
mocker.patch("mem0.utils.factory.EmbedderFactory.create", mock_embedder)
mock_vector_store = mocker.MagicMock()
mock_vector_store.return_value.search.return_value = []
mocker.patch('mem0.utils.factory.VectorStoreFactory.create',
side_effect=[mock_vector_store.return_value, mocker.MagicMock()])
mocker.patch(
"mem0.utils.factory.VectorStoreFactory.create", side_effect=[mock_vector_store.return_value, mocker.MagicMock()]
)
mock_llm = mocker.MagicMock()
mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm)
mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock())
mocker.patch("mem0.utils.factory.LlmFactory.create", mock_llm)
mocker.patch("mem0.memory.storage.SQLiteManager", mocker.MagicMock())
return mock_llm, mock_vector_store
@@ -30,29 +31,26 @@ class TestAddToVectorStoreErrors:
def mock_memory(self, mocker):
"""Fixture that returns a Memory instance with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)
memory = Memory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"
return memory
def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
"""Test empty response from LLM during fact extraction"""
# Setup
mock_memory.llm.generate_response.return_value = ""
# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed
@@ -62,20 +60,14 @@ class TestAddToVectorStoreErrors:
"""Test empty response from LLM during memory actions"""
# Setup
# First call returns valid JSON, second call returns empty string
mock_memory.llm.generate_response.side_effect = [
'{"facts": ["test fact"]}',
""
]
mock_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
# Execute
with caplog.at_level(logging.ERROR):
result = mock_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
# Verify
assert mock_memory.llm.generate_response.call_count == 2
assert result == [] # Should return empty list when no memories processed
@@ -88,48 +80,39 @@ class TestAsyncAddToVectorStoreErrors:
def mock_async_memory(self, mocker):
"""Fixture for AsyncMemory with mocker-based mocks"""
mock_llm, _ = _setup_mocks(mocker)
memory = AsyncMemory()
memory.config = mocker.MagicMock()
memory.config.custom_fact_extraction_prompt = None
memory.config.custom_update_memory_prompt = None
memory.api_version = "v1.1"
return memory
@pytest.mark.asyncio
async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.return_value = ""
with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
assert result == []
assert "Error in new_retrieved_facts" in caplog.text
@pytest.mark.asyncio
async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
"""Test empty response in AsyncMemory._add_to_vector_store"""
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
mock_async_memory.llm.generate_response.side_effect = [
'{"facts": ["test fact"]}',
""
]
mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock())
mock_async_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""]
with caplog.at_level(logging.ERROR):
result = await mock_async_memory._add_to_vector_store(
messages=[{"role": "user", "content": "test"}],
metadata={},
filters={},
infer=True
messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True
)
assert result == []
assert "Invalid JSON response" in caplog.text
View File
@@ -17,11 +17,13 @@ def mock_openai():
@pytest.fixture
def memory_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
with (
patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
patch("mem0.utils.factory.LlmFactory") as mock_llm,
patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
@@ -30,13 +32,16 @@ def memory_instance():
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@pytest.fixture
def memory_custom_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
with (
patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
patch("mem0.utils.factory.LlmFactory") as mock_llm,
patch("mem0.memory.telemetry.capture_event"),
patch("mem0.memory.graph_memory.MemoryGraph"),
):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
@@ -44,7 +49,7 @@ def memory_custom_instance():
config = MemoryConfig(
version="v1.1",
custom_fact_extraction_prompt="custom prompt extracting memory",
custom_update_memory_prompt="custom prompt determining memory update"
custom_update_memory_prompt="custom prompt determining memory update",
)
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph):
assert result["message"] == "Memories deleted successfully!"
@pytest.mark.parametrize(
"version, enable_graph, expected_result",
[
@@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else:
memory_instance.graph.get_all.assert_not_called()
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
memory_custom_instance.llm.generate_response = Mock()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
with patch(
"mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
) as mock_get_update_memory_messages:
memory_custom_instance.add(messages=messages, user_id="test_user")
## custom prompt
##
mock_parse_messages.assert_called_once_with(messages)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
@@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance):
],
response_format={"type": "json_object"},
)
## custom update memory prompt
##
mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
mock_get_update_memory_messages.assert_called_once_with(
[], [], memory_custom_instance.config.custom_update_memory_prompt
)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
response_format={"type": "json_object"},
)
)
View File
@@ -97,4 +97,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm
call_args = mock_litellm.completion.call_args[1]
assert call_args["messages"][0]["role"] == "system"
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
assert call_args["messages"][0]["content"] == "You are a helpful assistant."
View File
@@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch
# Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch.
@pytest.fixture
def mock_clients():
with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential:
with (
patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient,
patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient,
patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential,
):
# Create mocked instances for search and index clients.
mock_search_client = MockSearchClient.return_value
mock_index_client = MockIndexClient.return_value
# Mock the client._client._config.user_agent_policy.add_user_agent
mock_search_client._client = MagicMock()
mock_search_client._client._config.user_agent_policy.add_user_agent = Mock()
@@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients):
api_key="test-api-key",
embedding_model_dims=3,
compression_type="binary", # testing binary quantization option
use_float16=True
use_float16=True,
)
# Return instance and clients for verification.
return instance, mock_search_client, mock_index_client
@@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients):
# --- Tests for AzureAISearchConfig ---
def test_config_validation_valid():
"""Test valid configurations are accepted."""
# Test minimal configuration
config = AzureAISearchConfig(
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768
)
config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768)
assert config.collection_name == "mem0" # Default value
assert config.service_name == "test-service"
assert config.api_key == "test-api-key"
assert config.embedding_model_dims == 768
assert config.compression_type is None
assert config.use_float16 is False
# Test with all optional parameters
config = AzureAISearchConfig(
collection_name="custom-index",
@@ -92,7 +91,7 @@ def test_config_validation_valid():
api_key="test-api-key",
embedding_model_dims=1536,
compression_type="scalar",
use_float16=True
use_float16=True,
)
assert config.collection_name == "custom-index"
assert config.compression_type == "scalar"
@@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="invalid-type" # Not a valid option
compression_type="invalid-type", # Not a valid option
)
assert "Invalid compression_type" in str(exc_info.value)
@@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
use_compression=True # Deprecated parameter
use_compression=True, # Deprecated parameter
)
# Fix: Use a partial string match instead of exact match
assert "use_compression" in str(exc_info.value)
@@ -132,7 +131,7 @@ def test_config_validation_extra_fields():
service_name="test-service",
api_key="test-api-key",
embedding_model_dims=768,
unknown_parameter="value" # Extra field
unknown_parameter="value", # Extra field
)
assert "Extra fields not allowed" in str(exc_info.value)
assert "unknown_parameter" in str(exc_info.value)
@@ -140,30 +139,28 @@ def test_config_validation_extra_fields():
# --- Tests for AzureAISearch initialization ---
def test_initialization(mock_clients):
"""Test AzureAISearch initialization with different parameters."""
mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients
# Test with minimal parameters
instance = AzureAISearch(
service_name="test-service",
collection_name="test-index",
api_key="test-api-key",
embedding_model_dims=768
service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768
)
# Verify initialization parameters
assert instance.index_name == "test-index"
assert instance.collection_name == "test-index"
assert instance.embedding_model_dims == 768
assert instance.compression_type == "none" # Default when None is passed
assert instance.use_float16 is False
# Verify client creation
mock_azure_key_credential.assert_called_with("test-api-key")
assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0]
assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0]
# Verify index creation was called
mock_index_client.create_or_update_index.assert_called_once()
@@ -171,75 +168,75 @@ def test_initialization(mock_clients):
def test_initialization_with_compression_types(mock_clients):
"""Test initialization with different compression types."""
mock_search_client, mock_index_client, _ = mock_clients
# Test with scalar compression
instance = AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="scalar"
compression_type="scalar",
)
assert instance.compression_type == "scalar"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify scalar compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with binary compression
instance = AzureAISearch(
service_name="test-service",
collection_name="binary-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="binary"
compression_type="binary",
)
assert instance.compression_type == "binary"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify binary compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) > 0
assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Test with no compression
instance = AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type=None
compression_type=None,
)
assert instance.compression_type == "none"
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
# Verify no compression was configured
assert hasattr(index.vector_search, 'compressions')
assert hasattr(index.vector_search, "compressions")
assert len(index.vector_search.compressions) == 0
def test_initialization_with_float_precision(mock_clients):
"""Test initialization with different float precision settings."""
mock_search_client, mock_index_client, _ = mock_clients
# Test with half precision (float16)
instance = AzureAISearch(
service_name="test-service",
collection_name="float16-index",
api_key="test-api-key",
embedding_model_dims=768,
use_float16=True
use_float16=True,
)
assert instance.use_float16 is True
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients):
vector_field = next((f for f in index.fields if f.name == "vector"), None)
assert vector_field is not None
assert "Edm.Half" in vector_field.type
# Test with full precision (float32)
instance = AzureAISearch(
service_name="test-service",
collection_name="float32-index",
api_key="test-api-key",
embedding_model_dims=768,
use_float16=False
use_float16=False,
)
assert instance.use_float16 is False
# Capture the index creation call
args, _ = mock_index_client.create_or_update_index.call_args_list[-1]
index = args[0]
@@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients):
# --- Tests for create_col method ---
def test_create_col(azure_ai_search_instance):
"""Test the create_col method creates an index with the correct configuration."""
instance, _, mock_index_client = azure_ai_search_instance
# create_col is called during initialization, so we check the call that was already made
mock_index_client.create_or_update_index.assert_called_once()
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check basic properties
assert index.name == "test-index"
assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload
# Check that required fields are present
field_names = [f.name for f in index.fields]
assert "id" in field_names
@@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance):
assert "user_id" in field_names
assert "run_id" in field_names
assert "agent_id" in field_names
# Check that id is the key field
id_field = next(f for f in index.fields if f.name == "id")
assert id_field.key is True
# Check vector search configuration
assert index.vector_search is not None
assert len(index.vector_search.profiles) == 1
assert index.vector_search.profiles[0].name == "my-vector-config"
assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config"
# Check algorithms
assert len(index.vector_search.algorithms) == 1
assert index.vector_search.algorithms[0].name == "my-algorithms-config"
assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0]))
# With binary compression and float16, we should have compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
@@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance):
def test_create_col_scalar_compression(mock_clients):
"""Test creating a collection with scalar compression."""
mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch(
service_name="test-service",
collection_name="scalar-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type="scalar"
compression_type="scalar",
)
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check compression configuration
assert len(index.vector_search.compressions) == 1
assert index.vector_search.compressions[0].compression_name == "myCompression"
assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0]))
# Check profile references compression
assert index.vector_search.profiles[0].compression_name == "myCompression"
@@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients):
def test_create_col_no_compression(mock_clients):
"""Test creating a collection with no compression."""
mock_search_client, mock_index_client, _ = mock_clients
AzureAISearch(
service_name="test-service",
collection_name="no-compression-index",
api_key="test-api-key",
embedding_model_dims=768,
compression_type=None
compression_type=None,
)
# Verify the index configuration
args, _ = mock_index_client.create_or_update_index.call_args
index = args[0]
# Check compression configuration - should be empty
assert len(index.vector_search.compressions) == 0
# Check profile doesn't reference compression
assert index.vector_search.profiles[0].compression_name is None
# --- Tests for insert method ---
def test_insert_single(azure_ai_search_instance):
"""Test inserting a single vector."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance):
ids = ["doc1"]
# Fix: Include status_code: 201 in mock response
mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1", "status_code": 201}
]
mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}]
instance.insert(vectors, payloads, ids)
@@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance):
def test_insert_multiple(azure_ai_search_instance):
"""Test inserting multiple vectors in one call."""
instance, mock_search_client, _ = azure_ai_search_instance
# Create multiple vectors
num_docs = 3
vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)]
vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)]
payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)]
ids = [f"doc{i}" for i in range(num_docs)]
# Configure mock to return success for all documents (fix: add status_code 201)
mock_search_client.upload_documents.return_value = [
{"status": True, "id": id_val, "status_code": 201} for id_val in ids
]
# Insert the documents
instance.insert(vectors, payloads, ids)
# Verify upload_documents was called with correct documents
mock_search_client.upload_documents.assert_called_once()
args, _ = mock_search_client.upload_documents.call_args
documents = args[0]
# Verify all documents were included
assert len(documents) == num_docs
# Check first document
assert documents[0]["id"] == "doc0"
assert documents[0]["vector"] == [0.0, 0.1, 0.2]
assert documents[0]["payload"] == json.dumps(payloads[0])
assert documents[0]["user_id"] == "user0"
# Check last document
assert documents[2]["id"] == "doc2"
assert documents[2]["vector"] == [0.2, 0.3, 0.4]
@@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance):
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to return an error for one document
mock_search_client.upload_documents.return_value = [
{"status": False, "id": "doc1", "errorMessage": "Azure error"}
]
mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}]
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
@@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance):
# Configure mock to return mixed success/failure for multiple documents
mock_search_client.upload_documents.return_value = [
{"status": True, "id": "doc1"}, # This should not cause failure
{"status": False, "id": "doc2", "errorMessage": "Azure error"}
{"status": False, "id": "doc2", "errorMessage": "Azure error"},
]
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance):
with pytest.raises(Exception) as exc_info:
instance.insert(vectors, payloads, ids)
assert "Insert failed for document doc2" in str(exc_info.value) or \
"Insert failed for document doc1" in str(exc_info.value)
assert "Insert failed for document doc2" in str(exc_info.value) or "Insert failed for document doc1" in str(
exc_info.value
)
def test_insert_with_missing_payload_fields(azure_ai_search_instance):
@@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance):
def test_insert_with_http_error(azure_ai_search_instance):
"""Test insert when Azure client throws an HTTP error."""
instance, mock_search_client, _ = azure_ai_search_instance
# Configure mock to raise an HttpResponseError
mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error")
vectors = [[0.1, 0.2, 0.3]]
payloads = [{"user_id": "user1"}]
ids = ["doc1"]
# Insert should propagate the HTTP error
with pytest.raises(HttpResponseError) as exc_info:
instance.insert(vectors, payloads, ids)
assert "Azure service error" in str(exc_info.value)
# --- Tests for search method ---
def test_search_basic(azure_ai_search_instance):
"""Test basic vector search without filters."""
instance, mock_search_client, _ = azure_ai_search_instance
@@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance):
# Search with a vector
query_text = "test query" # Add a query string
query_vector = [0.1, 0.2, 0.3]
results = instance.search(
query_text, query_vector, limit=5
) # Pass the query string
results = instance.search(query_text, query_vector, limit=5) # Pass the query string
# Verify search was called correctly
mock_search_client.search.assert_called_once()

View File

@@ -7,9 +7,7 @@ import dotenv
try:
from elasticsearch import Elasticsearch
except ImportError:
raise ImportError(
"Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`"
) from None
raise ImportError("Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`") from None
from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData
@@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase):
def setUpClass(cls):
# Load environment variables before any test
dotenv.load_dotenv()
# Save original environment variables
cls.original_env = {
'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'),
'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'),
'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'),
'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id')
"ES_URL": os.getenv("ES_URL", "http://localhost:9200"),
"ES_USERNAME": os.getenv("ES_USERNAME", "test_user"),
"ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"),
"ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"),
}
# Set test environment variables
os.environ['ES_URL'] = 'http://localhost'
os.environ['ES_USERNAME'] = 'test_user'
os.environ['ES_PASSWORD'] = 'test_password'
os.environ["ES_URL"] = "http://localhost"
os.environ["ES_USERNAME"] = "test_user"
os.environ["ES_PASSWORD"] = "test_password"
def setUp(self):
# Create a mock Elasticsearch client with proper attributes
self.client_mock = MagicMock(spec=Elasticsearch)
@@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase):
self.client_mock.indices.create = MagicMock()
self.client_mock.indices.delete = MagicMock()
self.client_mock.indices.get_alias = MagicMock()
# Start patches BEFORE creating ElasticsearchDB instance
patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock)
patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock)
self.mock_es = patcher.start()
self.addCleanup(patcher.stop)
# Initialize ElasticsearchDB with test config and auto_create_index=False
self.es_db = ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=False # Disable auto creation for tests
auto_create_index=False, # Disable auto creation for tests
)
# Reset mock counts after initialization
self.client_mock.reset_mock()
@@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase):
# Test when index doesn't exist
self.client_mock.indices.exists.return_value = False
self.es_db.create_index()
# Verify index creation was called with correct settings
self.client_mock.indices.create.assert_called_once()
create_args = self.client_mock.indices.create.call_args[1]
# Verify basic index settings
self.assertEqual(create_args["index"], "test_collection")
self.assertIn("mappings", create_args["body"])
# Verify field mappings
mappings = create_args["body"]["mappings"]["properties"]
self.assertEqual(mappings["text"]["type"], "text")
@@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(mappings["vector"]["index"], True)
self.assertEqual(mappings["vector"]["similarity"], "cosine")
self.assertEqual(mappings["metadata"]["type"], "object")
# Reset mocks for next test
self.client_mock.reset_mock()
# Test when index already exists
self.client_mock.indices.exists.return_value = True
self.es_db.create_index()
# Verify create was not called when index exists
self.client_mock.indices.create.assert_not_called()
def test_auto_create_index(self):
# Reset mock
self.client_mock.reset_mock()
# Test with auto_create_index=True
ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=True
auto_create_index=True,
)
# Verify create_index was called during initialization
self.client_mock.indices.exists.assert_called_once()
# Reset mock
self.client_mock.reset_mock()
# Test with auto_create_index=False
ElasticsearchDB(
host=os.getenv('ES_URL'),
host=os.getenv("ES_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('ES_USERNAME'),
password=os.getenv('ES_PASSWORD'),
user=os.getenv("ES_USERNAME"),
password=os.getenv("ES_PASSWORD"),
verify_certs=False,
use_ssl=False,
auto_create_index=False
auto_create_index=False,
)
# Verify create_index was not called during initialization
self.client_mock.indices.exists.assert_not_called()
@@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"]
# Mock bulk operation
with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk:
with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk:
mock_bulk.return_value = (2, []) # Simulate successful bulk insert
# Perform insert
results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify bulk was called
mock_bulk.assert_called_once()
# Verify bulk actions format
actions = mock_bulk.call_args[0][1]
self.assertEqual(len(actions), 2)
@@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase):
self.assertEqual(actions[0]["_id"], "id1")
self.assertEqual(actions[0]["_source"]["vector"], vectors[0])
self.assertEqual(actions[0]["_source"]["metadata"], payloads[0])
# Verify returned objects
self.assertEqual(len(results), 2)
self.assertIsInstance(results[0], OutputData)
@@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_score": 0.8,
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
}
}
{"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}}
]
}
}
@@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase):
# Verify search parameters
self.assertEqual(search_args["index"], "test_collection")
body = search_args["body"]
# Verify KNN query structure
self.assertIn("knn", body)
self.assertEqual(body["knn"]["field"], "vector")
@@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase):
self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters)
# Verify custom search query was used
self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"})
self.client_mock.search.assert_called_once_with(
index=self.es_db.collection_name, body={"custom_key": "custom_value"}
)
def test_get(self):
# Mock get response with correct structure
mock_response = {
"_id": "id1",
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key": "value"},
"text": "sample text"
}
"_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"},
}
self.client_mock.get.return_value = mock_response
# Perform get
result = self.es_db.get(vector_id="id1")
# Verify get call
self.client_mock.get.assert_called_once_with(
index="test_collection",
id="id1"
)
self.client_mock.get.assert_called_once_with(index="test_collection", id="id1")
# Verify result
self.assertIsNotNone(result)
self.assertEqual(result.id, "id1")
@@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase):
def test_get_not_found(self):
# Mock get raising exception
self.client_mock.get.side_effect = Exception("Not found")
# Verify get returns None when document not found
result = self.es_db.get(vector_id="nonexistent")
self.assertIsNone(result)
@@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase):
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_source": {
"vector": [0.1] * 1536,
"metadata": {"key1": "value1"}
},
"_score": 1.0
},
{
"_id": "id2",
"_source": {
"vector": [0.2] * 1536,
"metadata": {"key2": "value2"}
},
"_score": 0.8
}
{"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0},
{"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8},
]
}
}
self.client_mock.search.return_value = mock_response
# Perform list operation
results = self.es_db.list(limit=10)
# Verify search call
self.client_mock.search.assert_called_once()
# Verify results
self.assertEqual(len(results), 1) # Outer list
self.assertEqual(len(results[0]), 2) # Inner list
@@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase):
def test_delete(self):
# Perform delete
self.es_db.delete(vector_id="id1")
# Verify delete call
self.client_mock.delete.assert_called_once_with(
index="test_collection",
id="id1"
)
self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1")
def test_list_cols(self):
# Mock indices response
mock_indices = {"index1": {}, "index2": {}}
self.client_mock.indices.get_alias.return_value = mock_indices
# Get collections
result = self.es_db.list_cols()
# Verify result
self.assertEqual(result, ["index1", "index2"])
def test_delete_col(self):
# Delete collection
self.es_db.delete_col()
# Verify delete call
self.client_mock.indices.delete.assert_called_once_with(
index="test_collection"
)
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")

View File

@@ -21,9 +21,9 @@ def mock_faiss_index():
def faiss_instance(mock_faiss_index):
with tempfile.TemporaryDirectory() as temp_dir:
# Mock the faiss index creation
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index):
with patch("faiss.IndexFlatL2", return_value=mock_faiss_index):
# Mock the faiss.write_index function
with patch('faiss.write_index'):
with patch("faiss.write_index"):
# Create a FAISS instance with a temporary directory
faiss_store = FAISS(
collection_name="test_collection",
@@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index):
def test_create_col(faiss_instance, mock_faiss_index):
# Test creating a collection with euclidean distance
with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
with patch('faiss.write_index'):
with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as mock_index_flat_l2:
with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection")
mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)
# Test creating a collection with inner product distance
with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
with patch('faiss.write_index'):
with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip:
with patch("faiss.write_index"):
faiss_instance.create_col(name="new_collection", distance="inner_product")
mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
@@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index):
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"name": "vector1"}, {"name": "vector2"}]
ids = ["id1", "id2"]
# Mock the numpy array conversion
with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
# Mock index.add
mock_faiss_index.add.return_value = None
# Call insert
faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify numpy.array was called
mock_np_array.assert_called_once_with(vectors, dtype=np.float32)
# Verify index.add was called
mock_faiss_index.add.assert_called_once()
# Verify docstore and index_to_id were updated
assert faiss_instance.docstore["id1"] == {"name": "vector1"}
assert faiss_instance.docstore["id2"] == {"name": "vector2"}
@@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index):
def test_search(faiss_instance, mock_faiss_index):
# Prepare test data
query_vector = [0.1, 0.2, 0.3]
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# First, create the mock for the search return values
search_scores = np.array([[0.9, 0.8]])
search_indices = np.array([[0, 1]])
mock_faiss_index.search.return_value = (search_scores, search_indices)
# Then patch numpy.array only for the query vector conversion
with patch('numpy.array') as mock_np_array:
with patch("numpy.array") as mock_np_array:
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
# Then patch _parse_output to return the expected results
expected_results = [
OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
OutputData(id="id2", score=0.8, payload={"name": "vector2"})
OutputData(id="id2", score=0.8, payload={"name": "vector2"}),
]
with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
with patch.object(faiss_instance, "_parse_output", return_value=expected_results):
# Call search
results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)
# Verify numpy.array was called (but we don't check exact call arguments since it's complex)
assert mock_np_array.called
# Verify index.search was called
mock_faiss_index.search.assert_called_once()
# Verify results
assert len(results) == 2
assert results[0].id == "id1"
@@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index):
def test_search_with_filters(faiss_instance, mock_faiss_index):
# Prepare test data
query_vector = [0.1, 0.2, 0.3]
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1", "category": "A"},
"id2": {"name": "vector2", "category": "B"}
}
faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# First set up the search return values
search_scores = np.array([[0.9, 0.8]])
search_indices = np.array([[0, 1]])
mock_faiss_index.search.return_value = (search_scores, search_indices)
# Patch numpy.array for query vector conversion
with patch('numpy.array') as mock_np_array:
with patch("numpy.array") as mock_np_array:
mock_np_array.return_value = np.array(query_vector, dtype=np.float32)
# Directly mock the _parse_output method to return our expected values
# We're simulating that _parse_output filters to just the first result
all_results = [
OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"})
OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}),
]
# Replace the _apply_filters method to handle our test case
with patch.object(faiss_instance, '_parse_output', return_value=all_results):
with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"):
with patch.object(faiss_instance, "_parse_output", return_value=all_results):
with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"):
# Call search with filters
results = faiss_instance.search(
query="test query",
vectors=query_vector,
limit=2,
filters={"category": "A"}
query="test query", vectors=query_vector, limit=2, filters={"category": "A"}
)
# Verify numpy.array was called
assert mock_np_array.called
# Verify index.search was called
mock_faiss_index.search.assert_called_once()
# Verify filtered results - since we've mocked everything,
# we should get just the result we want
assert len(results) == 1
@@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index):
def test_delete(faiss_instance):
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# Call delete
faiss_instance.delete(vector_id="id1")
# Verify the vector was removed from docstore and index_to_id
assert "id1" not in faiss_instance.docstore
assert 0 not in faiss_instance.index_to_id
@@ -194,23 +182,20 @@ def test_delete(faiss_instance):
def test_update(faiss_instance, mock_faiss_index):
# Setup the docstore and index_to_id mapping
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
faiss_instance.index_to_id = {0: "id1", 1: "id2"}
# Test updating payload only
faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}
# Test updating vector
# This requires mocking the delete and insert methods
with patch.object(faiss_instance, 'delete') as mock_delete:
with patch.object(faiss_instance, 'insert') as mock_insert:
with patch.object(faiss_instance, "delete") as mock_delete:
with patch.object(faiss_instance, "insert") as mock_insert:
new_vector = [0.7, 0.8, 0.9]
faiss_instance.update(vector_id="id2", vector=new_vector)
# Verify delete and insert were called
# Match the actual call signature (positional arg instead of keyword)
mock_delete.assert_called_once_with("id2")
@@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index):
def test_get(faiss_instance):
# Setup the docstore
faiss_instance.docstore = {
"id1": {"name": "vector1"},
"id2": {"name": "vector2"}
}
faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}}
# Test getting an existing vector
result = faiss_instance.get(vector_id="id1")
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
assert result.score is None
# Test getting a non-existent vector
result = faiss_instance.get(vector_id="id3")
assert result is None
@@ -240,18 +222,18 @@ def test_list(faiss_instance):
faiss_instance.docstore = {
"id1": {"name": "vector1", "category": "A"},
"id2": {"name": "vector2", "category": "B"},
"id3": {"name": "vector3", "category": "A"}
"id3": {"name": "vector3", "category": "A"},
}
# Test listing all vectors
results = faiss_instance.list()
# Fix the expected result - the list method returns a list of lists
assert len(results[0]) == 3
# Test listing with a limit
results = faiss_instance.list(limit=2)
assert len(results[0]) == 2
# Test listing with filters
results = faiss_instance.list(filters={"category": "A"})
assert len(results[0]) == 2
@@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index):
# Mock index attributes
mock_faiss_index.ntotal = 5
mock_faiss_index.d = 128
# Get collection info
info = faiss_instance.col_info()
# Verify the returned info
assert info["name"] == "test_collection"
assert info["count"] == 5
@@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index):
def test_delete_col(faiss_instance):
# Mock the os.remove function
with patch('os.remove') as mock_remove:
with patch('os.path.exists', return_value=True):
with patch("os.remove") as mock_remove:
with patch("os.path.exists", return_value=True):
# Call delete_col
faiss_instance.delete_col()
# Verify os.remove was called twice (for index and docstore files)
assert mock_remove.call_count == 2
# Verify the internal state was reset
assert faiss_instance.index is None
assert faiss_instance.docstore == {}
@@ -293,17 +275,17 @@ def test_delete_col(faiss_instance):
def test_normalize_L2(faiss_instance, mock_faiss_index):
# Setup a FAISS instance with normalize_L2=True
faiss_instance.normalize_L2 = True
# Prepare test data
vectors = [[0.1, 0.2, 0.3]]
# Mock numpy array conversion
# Mock numpy array conversion
with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)):
with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)):
# Mock faiss.normalize_L2
with patch('faiss.normalize_L2') as mock_normalize:
with patch("faiss.normalize_L2") as mock_normalize:
# Call insert
faiss_instance.insert(vectors=vectors, ids=["id1"])
# Verify faiss.normalize_L2 was called
mock_normalize.assert_called_once()

View File

@@ -11,11 +11,13 @@ def mock_langchain_client():
with patch("langchain_community.vectorstores.VectorStore") as mock_client:
yield mock_client
@pytest.fixture
def langchain_instance(mock_langchain_client):
mock_client = Mock(spec=VectorStore)
return Langchain(client=mock_client, collection_name="test_collection")
def test_insert_vectors(langchain_instance):
# Test data
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance):
# Test with add_embeddings method
langchain_instance.client.add_embeddings = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_embeddings.assert_called_once_with(
embeddings=vectors,
metadatas=payloads,
ids=ids
)
langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids)
# Test with add_texts method
delattr(langchain_instance.client, "add_embeddings") # Remove attribute completely
langchain_instance.client.add_texts = Mock()
langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["text1", "text2"],
metadatas=payloads,
ids=ids
)
langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids)
# Test with empty payloads
langchain_instance.client.add_texts.reset_mock()
langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
langchain_instance.client.add_texts.assert_called_once_with(
texts=["", ""],
metadatas=None,
ids=ids
)
langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids)
def test_search_vectors(langchain_instance):
# Mock search results
mock_docs = [
Mock(metadata={"name": "vector1"}, id="id1"),
Mock(metadata={"name": "vector2"}, id="id2")
]
mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")]
langchain_instance.client.similarity_search_by_vector.return_value = mock_docs
# Test search without filters
vectors = [[0.1, 0.2, 0.3]]
results = langchain_instance.search(query="", vectors=vectors, limit=2)
langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
embedding=vectors,
k=2
)
langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].payload == {"name": "vector1"}
@@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance):
# Test search with filters
filters = {"name": "vector1"}
langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
langchain_instance.client.similarity_search_by_vector.assert_called_with(
embedding=vectors,
k=2,
filter=filters
)
langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters)
def test_get_vector(langchain_instance):
# Mock get result
@@ -90,7 +72,7 @@ def test_get_vector(langchain_instance):
# Test get existing vector
result = langchain_instance.get("id1")
langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])
assert result is not None
assert result.id == "id1"
assert result.payload == {"name": "vector1"}

View File

@@ -8,9 +8,7 @@ import pytest
try:
from opensearchpy import AWSV4SignerAuth, OpenSearch
except ImportError:
raise ImportError(
"OpenSearch requires extra dependencies. Install with `pip install opensearch-py`"
) from None
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
from mem0.vector_stores.opensearch import OpenSearchDB
@@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase):
def setUpClass(cls):
dotenv.load_dotenv()
cls.original_env = {
'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'),
'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'),
'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password')
"OS_URL": os.getenv("OS_URL", "http://localhost:9200"),
"OS_USERNAME": os.getenv("OS_USERNAME", "test_user"),
"OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"),
}
os.environ['OS_URL'] = 'http://localhost'
os.environ['OS_USERNAME'] = 'test_user'
os.environ['OS_PASSWORD'] = 'test_password'
os.environ["OS_URL"] = "http://localhost"
os.environ["OS_USERNAME"] = "test_user"
os.environ["OS_PASSWORD"] = "test_password"
def setUp(self):
self.client_mock = MagicMock(spec=OpenSearch)
@@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase):
self.client_mock.delete = MagicMock()
self.client_mock.search = MagicMock()
patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock)
patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock)
self.mock_os = patcher.start()
self.addCleanup(patcher.stop)
self.os_db = OpenSearchDB(
host=os.getenv('OS_URL'),
host=os.getenv("OS_URL"),
port=9200,
collection_name="test_collection",
embedding_model_dims=1536,
user=os.getenv('OS_USERNAME'),
password=os.getenv('OS_PASSWORD'),
user=os.getenv("OS_USERNAME"),
password=os.getenv("OS_PASSWORD"),
verify_certs=False,
use_ssl=False
use_ssl=False,
)
self.client_mock.reset_mock()
@@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase):
vectors = [[0.1] * 1536, [0.2] * 1536]
payloads = [{"key1": "value1"}, {"key2": "value2"}]
ids = ["id1", "id2"]
# Mock the index method
self.client_mock.index = MagicMock()
results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# Verify index was called twice (once for each vector)
self.assertEqual(self.client_mock.index.call_count, 2)
# Check first call
first_call = self.client_mock.index.call_args_list[0]
self.assertEqual(first_call[1]["index"], "test_collection")
self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0])
self.assertEqual(first_call[1]["body"]["payload"], payloads[0])
self.assertEqual(first_call[1]["body"]["id"], ids[0])
# Check second call
second_call = self.client_mock.index.call_args_list[1]
self.assertEqual(second_call[1]["index"], "test_collection")
self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1])
self.assertEqual(second_call[1]["body"]["payload"], payloads[1])
self.assertEqual(second_call[1]["body"]["id"], ids[1])
# Check results
self.assertEqual(len(results), 2)
self.assertEqual(results[0].id, "id1")
@@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase):
self.client_mock.search.return_value = {"hits": {"hits": []}}
result = self.os_db.get("nonexistent")
self.assertIsNone(result)
def test_update(self):
vector = [0.3] * 1536
payload = {"key3": "value3"}
@@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase):
self.assertEqual(result, ["test_collection"])
def test_search(self):
mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}}
mock_response = {
"hits": {
"hits": [
{
"_id": "id1",
"_score": 0.8,
"_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}},
}
]
}
}
self.client_mock.search.return_value = mock_response
vectors = [[0.1] * 1536]
results = self.os_db.search(query="", vectors=vectors, limit=5)
@@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase):
self.os_db.delete_col()
self.client_mock.indices.delete.assert_called_once_with(index="test_collection")
def test_init_with_http_auth(self):
mock_credentials = MagicMock()
mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es")
with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch:
with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch:
OpenSearchDB(
host="localhost",
port=9200,
@@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase):
embedding_model_dims=1536,
http_auth=mock_signer,
verify_certs=True,
use_ssl=True
use_ssl=True,
)
# Verify OpenSearch was initialized with correct params
@@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase):
use_ssl=True,
verify_certs=True,
connection_class=unittest.mock.ANY,
pool_maxsize=20
)
pool_maxsize=20,
)

View File

@@ -12,6 +12,7 @@ def mock_pinecone_client():
client.list_indexes.return_value.names.return_value = []
return client
@pytest.fixture
def pinecone_db(mock_pinecone_client):
return PineconeDB(
@@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client):
hybrid_search=False,
metric="cosine",
batch_size=100,
extra_params=None
extra_params=None,
)
def test_create_col_existing_index(mock_pinecone_client):
# Set up the mock before creating the PineconeDB object
mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]
pinecone_db = PineconeDB(
collection_name="test_index",
embedding_model_dims=128,
@@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client):
hybrid_search=False,
metric="cosine",
batch_size=100,
extra_params=None
extra_params=None,
)
# Reset the mock to verify it wasn't called during the test
mock_pinecone_client.create_index.reset_mock()
pinecone_db.create_col(128, "cosine")
mock_pinecone_client.create_index.assert_not_called()
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
mock_pinecone_client.list_indexes.return_value.names.return_value = []
pinecone_db.create_col(128, "cosine")
mock_pinecone_client.create_index.assert_called()
def test_insert_vectors(pinecone_db):
vectors = [[0.1] * 128, [0.2] * 128]
payloads = [{"name": "vector1"}, {"name": "vector2"}]
@@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db):
pinecone_db.insert(vectors, payloads, ids)
pinecone_db.index.upsert.assert_called()
def test_search_vectors(pinecone_db):
pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
results = pinecone_db.search("test query",[0.1] * 128, limit=1)
results = pinecone_db.search("test query", [0.1] * 128, limit=1)
assert len(results) == 1
assert results[0].id == "id1"
assert results[0].score == 0.9
def test_update_vector(pinecone_db):
pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})
pinecone_db.index.upsert.assert_called()
def test_get_vector_found(pinecone_db):
# Looking at the _parse_output method, it expects a Vector object
# or a list of dictionaries, not a dictionary with an 'id' field
# Create a mock Vector object
from pinecone.data.dataclasses.vector import Vector
mock_vector = Vector(
id="id1",
values=[0.1] * 128,
metadata={"name": "vector1"}
)
mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})
# Mock the fetch method to return the mock response object
mock_response = MagicMock()
mock_response.vectors = {"id1": mock_vector}
pinecone_db.index.fetch.return_value = mock_response
result = pinecone_db.get("id1")
assert result is not None
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
def test_delete_vector(pinecone_db):
pinecone_db.delete("id1")
pinecone_db.index.delete.assert_called_with(ids=["id1"])
def test_get_vector_not_found(pinecone_db):
pinecone_db.index.fetch.return_value.vectors = {}
result = pinecone_db.get("id1")
assert result is None
def test_list_cols(pinecone_db):
pinecone_db.list_cols()
pinecone_db.client.list_indexes.assert_called()
def test_delete_col(pinecone_db):
pinecone_db.delete_col()
pinecone_db.client.delete_index.assert_called_with("test_index")
def test_col_info(pinecone_db):
pinecone_db.col_info()
pinecone_db.client.describe_index.assert_called_with("test_index")

View File

@@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection):
index_method=IndexMethod.HNSW,
index_measure=IndexMeasure.COSINE,
)
# Manually set the collection attribute since we're mocking the initialization
instance.collection = mock_collection
return instance
@@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection):
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
supabase_instance.create_col(1536)
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
name="test_collection",
dimension=1536
)
mock_collection.create_index.assert_called_with(
method="hnsw",
measure="cosine_distance"
)
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")
def test_insert_vectors(supabase_instance, mock_collection):
@@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection):
supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
expected_records = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
mock_collection.upsert.assert_called_once_with(expected_records)
def test_search_vectors(supabase_instance, mock_collection):
mock_results = [
("id1", 0.9, {"name": "vector1"}),
("id2", 0.8, {"name": "vector2"})
]
mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})]
mock_collection.query.return_value = mock_results
vectors = [[0.1, 0.2, 0.3]]
@@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection):
results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters)
mock_collection.query.assert_called_once_with(
data=vectors,
limit=2,
filters={"category": {"$eq": "test"}},
include_metadata=True,
include_value=True
data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True
)
assert len(results) == 2
@@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection):
def test_list_vectors(supabase_instance, mock_collection):
mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
mock_fetch_results = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})]
mock_collection.query.return_value = mock_query_results
mock_collection.fetch.return_value = mock_fetch_results
@@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection):
"name": "test_collection",
"count": 100,
"dimension": 1536,
"index": {
"method": "hnsw",
"metric": "cosine_distance"
}
"index": {"method": "hnsw", "metric": "cosine_distance"},
}
@@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance):
# Test multiple filters
multi_filter = {"category": "test", "type": "document"}
assert supabase_instance._preprocess_filters(multi_filter) == {
"$and": [
{"category": {"$eq": "test"}},
{"type": {"$eq": "document"}}
]
"$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}]
}
# Test None filters

View File

@@ -29,9 +29,7 @@ def upstash_instance(mock_index):
@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
return UpstashVector(
client=mock_index.return_value, collection_name="ns", enable_embeddings=True
)
return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True)
def test_insert_vectors(upstash_instance, mock_index):
@@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index):
def test_search_vectors(upstash_instance, mock_index):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
),
QueryResult(
id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
),
QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None),
QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None),
]
upstash_instance.client.query_many.return_value = [mock_result]
@@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance):
upstash_instance.delete(vector_id=vector_id)
upstash_instance.client.delete.assert_called_once_with(
ids=[vector_id], namespace="ns"
)
upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns")
def test_update_vector(upstash_instance):
@@ -115,18 +107,12 @@ def test_update_vector(upstash_instance):
def test_get_vector(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
)
]
mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)]
upstash_instance.client.fetch.return_value = mock_result
result = upstash_instance.get(vector_id="id1")
upstash_instance.client.fetch.assert_called_once_with(
ids=["id1"], namespace="ns", include_metadata=True
)
upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
@@ -134,15 +120,9 @@ def test_get_vector(upstash_instance):
def test_list_vectors(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
),
QueryResult(
id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
),
QueryResult(
id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
),
QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
]
handler = MagicMock()
@@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i
def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
),
QueryResult(
id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
),
QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
]
upstash_instance_with_embeddings.client.query.return_value = mock_result
@@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed
ValueError,
match="When embeddings are enabled, all payloads must contain a 'data' field",
):
upstash_instance_with_embeddings.insert(
vectors=vectors, payloads=payloads, ids=ids
)
upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)
def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
@@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance):
result = upstash_instance.get(vector_id="nonexistent")
upstash_instance.client.fetch.assert_called_once_with(
ids=["nonexistent"], namespace="ns", include_metadata=True
)
upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
assert result is None
def test_search_vectors_empty_filters(upstash_instance):
mock_result = [
QueryResult(
id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
)
]
mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)]
upstash_instance.client.query_many.return_value = [mock_result]
vectors = [[0.1, 0.2, 0.3]]

View File

@@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine
@pytest.fixture
def mock_vertex_ai():
with patch('google.cloud.aiplatform.MatchingEngineIndex') as mock_index, \
patch('google.cloud.aiplatform.MatchingEngineIndexEndpoint') as mock_endpoint, \
patch('google.cloud.aiplatform.init') as mock_init:
with (
patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index,
patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint,
patch("google.cloud.aiplatform.init") as mock_init,
):
mock_index_instance = Mock()
mock_endpoint_instance = Mock()
yield {
'index': mock_index_instance,
'endpoint': mock_endpoint_instance,
'init': mock_init,
'index_class': mock_index,
'endpoint_class': mock_endpoint
"index": mock_index_instance,
"endpoint": mock_endpoint_instance,
"init": mock_init,
"index_class": mock_index,
"endpoint_class": mock_endpoint,
}
@pytest.fixture
def config():
return GoogleMatchingEngineConfig(
project_id='test-project',
project_number='123456789',
region='us-central1',
endpoint_id='test-endpoint',
index_id='test-index',
deployment_index_id='test-deployment',
collection_name='test-collection',
vector_search_api_endpoint='test.vertexai.goog'
project_id="test-project",
project_number="123456789",
region="us-central1",
endpoint_id="test-endpoint",
index_id="test-index",
deployment_index_id="test-deployment",
collection_name="test-collection",
vector_search_api_endpoint="test.vertexai.goog",
)
@pytest.fixture
def vector_store(config, mock_vertex_ai):
mock_vertex_ai['index_class'].return_value = mock_vertex_ai['index']
mock_vertex_ai['endpoint_class'].return_value = mock_vertex_ai['endpoint']
mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"]
mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"]
return GoogleMatchingEngine(**config.model_dump())
def test_initialization(vector_store, mock_vertex_ai, config):
"""Test proper initialization of GoogleMatchingEngine"""
mock_vertex_ai['init'].assert_called_once_with(
project=config.project_id,
location=config.region
)
mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region)
expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}"
mock_vertex_ai['index_class'].assert_called_once_with(index_name=expected_index_path)
mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path)
def test_insert_vectors(vector_store, mock_vertex_ai):
"""Test inserting vectors with payloads"""
@@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai):
vector_store.insert(vectors=vectors, payloads=payloads, ids=ids)
mock_vertex_ai['index'].upsert_datapoints.assert_called_once()
call_args = mock_vertex_ai['index'].upsert_datapoints.call_args[1]
assert len(call_args['datapoints']) == 1
datapoint_str = str(call_args['datapoints'][0])
mock_vertex_ai["index"].upsert_datapoints.assert_called_once()
call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1]
assert len(call_args["datapoints"]) == 1
datapoint_str = str(call_args["datapoints"][0])
assert "test-id" in datapoint_str
assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str
def test_search_vectors(vector_store, mock_vertex_ai):
"""Test searching vectors with filters"""
vectors = [[0.1, 0.2, 0.3]]
@@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai):
mock_restrict.allow_list = ["test_user"]
mock_restrict.name = "user_id"
mock_restrict.allow_tokens = ["test_user"]
mock_datapoint.restricts = [mock_restrict]
mock_neighbor = Mock()
@@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai):
mock_neighbor.datapoint = mock_datapoint
mock_neighbor.restricts = [mock_restrict]
mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]]
mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]]
results = vector_store.search(query="", vectors=vectors, filters=filters, limit=1)
mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with(
mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with(
deployed_index_id=vector_store.deployment_index_id,
queries=[vectors],
num_neighbors=1,
filter=[Namespace("user_id", ["test_user"], [])],
return_full_datapoint=True
return_full_datapoint=True,
)
assert len(results) == 1
@@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai):
assert results[0].score == 0.1
assert results[0].payload == {"user_id": "test_user"}
def test_delete(vector_store, mock_vertex_ai):
"""Test deleting vectors"""
vector_id = "test-id"
remove_mock = Mock()
with patch.object(GoogleMatchingEngine, 'delete', wraps=vector_store.delete) as delete_spy:
with patch.object(vector_store.index, 'remove_datapoints', remove_mock):
with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy:
with patch.object(vector_store.index, "remove_datapoints", remove_mock):
vector_store.delete(ids=[vector_id])
delete_spy.assert_called_once_with(ids=[vector_id])
remove_mock.assert_called_once_with(datapoint_ids=[vector_id])
def test_error_handling(vector_store, mock_vertex_ai):
"""Test error handling during operations"""
mock_vertex_ai['index'].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request")
with pytest.raises(Exception) as exc_info:
vector_store.insert(
vectors=[[0.1, 0.2, 0.3]],
payloads=[{"name": "test"}],
ids=["test-id"]
)
vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"])
assert isinstance(exc_info.value, exceptions.InvalidArgument)
assert "Invalid request" in str(exc_info.value)

View File

@@ -76,15 +76,15 @@
# self.client_mock.batch = MagicMock()
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
# "results": [{"id": "id1"}, {"id": "id2"}]
# }
# vectors = [[0.1] * 1536, [0.2] * 1536]
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
# def test_get(self):
@@ -108,7 +108,7 @@
# result = self.weaviate_db.get(vector_id=valid_uuid)
# assert result.id == valid_uuid
# expected_payload = mock_response.properties.copy()
# expected_payload["id"] = valid_uuid
@@ -131,10 +131,10 @@
# "metadata": {"distance": 0.2}
# }
# ]
# mock_response = MagicMock()
# mock_response.objects = []
# for obj in mock_objects:
# mock_obj = MagicMock()
# mock_obj.uuid = obj["uuid"]
@@ -142,16 +142,16 @@
# mock_obj.metadata = MagicMock()
# mock_obj.metadata.distance = obj["metadata"]["distance"]
# mock_response.objects.append(mock_obj)
# mock_hybrid = MagicMock()
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
# mock_hybrid.return_value = mock_response
# vectors = [[0.1] * 1536]
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
# mock_hybrid.assert_called_once()
# self.assertEqual(len(results), 1)
# self.assertEqual(results[0].id, "id1")
# self.assertEqual(results[0].score, 0.8)
@@ -163,28 +163,28 @@
# def test_list(self):
# mock_objects = []
# mock_obj1 = MagicMock()
# mock_obj1.uuid = "id1"
# mock_obj1.properties = {"key1": "value1"}
# mock_objects.append(mock_obj1)
# mock_obj2 = MagicMock()
# mock_obj2.uuid = "id2"
# mock_obj2.properties = {"key2": "value2"}
# mock_objects.append(mock_obj2)
# mock_response = MagicMock()
# mock_response.objects = mock_objects
# mock_fetch = MagicMock()
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
# mock_fetch.return_value = mock_response
# results = self.weaviate_db.list(limit=10)
# mock_fetch.assert_called_once()
# # Verify results
# self.assertEqual(len(results), 1)
# self.assertEqual(len(results[0]), 2)