From d85fcda0378e3e2ba8772705b075049484bd8fa8 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Thu, 22 May 2025 01:17:29 +0530
Subject: [PATCH] Formatting (#2750)

---
 .github/workflows/ci.yml | 1 +
 cookbooks/customer-support-chatbot.ipynb | 36 +-
 cookbooks/mem0-autogen.ipynb | 30 +-
 evaluation/evals.py | 57 ++-
 evaluation/generate_scores.py | 20 +-
 evaluation/metrics/llm_judge.py | 56 +--
 evaluation/metrics/utils.py | 108 ++--
 evaluation/prompts.py | 2 +-
 evaluation/run_experiments.py | 59 +--
 evaluation/src/langmem.py | 76 ++-
 evaluation/src/memzero/add.py | 36 +-
 evaluation/src/memzero/search.py | 160 +++---
 evaluation/src/openai/predict.py | 49 +-
 evaluation/src/rag.py | 83 ++--
 evaluation/src/utils.py | 13 +-
 evaluation/src/zep/add.py | 20 +-
 evaluation/src/zep/search.py | 63 ++-
 examples/graph-db-demo/memgraph-example.ipynb | 10 +-
 examples/misc/fitness_checker.py | 107 ++--
 .../misc/healthcare_assistant_google_adk.py | 77 +--
 examples/misc/movie_recommendation_grok3.py | 30 +-
 examples/misc/personal_assistant_agno.py | 31 +-
 examples/misc/study_buddy.py | 28 +-
 examples/misc/voice_assistant_elevenlabs.py | 47 +-
 mem0/client/main.py | 77 +--
 mem0/configs/vector_stores/opensearch.py | 39 +-
 mem0/embeddings/aws_bedrock.py | 6 +-
 mem0/embeddings/huggingface.py | 1 +
 mem0/graphs/configs.py | 3 +-
 mem0/llms/aws_bedrock.py | 26 +-
 mem0/memory/main.py | 463 +++++++++---------
 mem0/memory/memgraph_memory.py | 76 +--
 mem0/memory/storage.py | 14 +-
 mem0/memory/telemetry.py | 4 +-
 mem0/utils/factory.py | 3 +-
 mem0/vector_stores/azure_ai_search.py | 1 -
 mem0/vector_stores/base.py | 2 +-
 mem0/vector_stores/chroma.py | 2 +-
 mem0/vector_stores/elasticsearch.py | 9 +-
 mem0/vector_stores/faiss.py | 2 +-
 mem0/vector_stores/langchain.py | 6 +-
 mem0/vector_stores/milvus.py | 2 +-
 mem0/vector_stores/opensearch.py | 100 +---
 mem0/vector_stores/pgvector.py | 2 +-
 mem0/vector_stores/qdrant.py | 2 +-
 mem0/vector_stores/redis.py | 11 +-
 mem0/vector_stores/supabase.py | 2 +-
 mem0/vector_stores/upstash_vector.py | 3 +-
 mem0/vector_stores/weaviate.py | 2 +-
 server/main.py | 39 +-
 tests/configs/test_prompts.py | 10 +-
 tests/embeddings/test_lm_studio_embeddings.py | 4 +-
 tests/embeddings/test_openai_embeddings.py | 14 +-
 tests/embeddings/test_vertexai_embeddings.py | 38 +-
 tests/llms/test_azure_openai.py | 2 +-
 tests/llms/test_deepseek.py | 28 +-
 tests/llms/test_langchain.py | 51 +-
 tests/llms/test_lm_studio.py | 4 +-
 tests/memory/test_main.py | 77 ++-
 tests/test_main.py | 50 +-
 tests/test_proxy.py | 2 +-
 tests/vector_stores/test_azure_ai_search.py | 161 +++---
 tests/vector_stores/test_elasticsearch.py | 166 +++----
 tests/vector_stores/test_faiss.py | 148 +++---
 .../test_langchain_vector_store.py | 44 +-
 tests/vector_stores/test_opensearch.py | 61 ++-
 tests/vector_stores/test_pinecone.py | 39 +-
 tests/vector_stores/test_supabase.py | 45 +-
 tests/vector_stores/test_upstash_vector.py | 60 +--
 .../test_vertex_ai_vector_search.py | 80 +--
 tests/vector_stores/test_weaviate.py | 34 +-
 71 files changed, 1391 insertions(+), 1823 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 325e6654..ef6bf158 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -56,6 +56,7 @@ jobs:
       run: |
         make install_all
         pip install -e ".[test]"
+        pip install pinecone pinecone-text
       if: steps.cached-hatch-dependencies.outputs.cache-hit != 'true'
     - name: Run Formatting
       run: |
diff --git a/cookbooks/customer-support-chatbot.ipynb b/cookbooks/customer-support-chatbot.ipynb
index 863f156e..fe7dd6fd 100644
--- a/cookbooks/customer-support-chatbot.ipynb +++ b/cookbooks/customer-support-chatbot.ipynb @@ -13,7 +13,7 @@ "import anthropic\n", "\n", "# Set up environment variables\n", - "os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n", + "os.environ[\"OPENAI_API_KEY\"] = \"your_openai_api_key\" # needed for embedding model\n", "os.environ[\"ANTHROPIC_API_KEY\"] = \"your_anthropic_api_key\"" ] }, @@ -33,7 +33,7 @@ " \"model\": \"claude-3-5-sonnet-latest\",\n", " \"temperature\": 0.1,\n", " \"max_tokens\": 2000,\n", - " }\n", + " },\n", " }\n", " }\n", " self.client = anthropic.Client(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n", @@ -50,11 +50,7 @@ " - Keep track of open issues and follow-ups\n", " \"\"\"\n", "\n", - " def store_customer_interaction(self,\n", - " user_id: str,\n", - " message: str,\n", - " response: str,\n", - " metadata: Dict = None):\n", + " def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):\n", " \"\"\"Store customer interaction in memory.\"\"\"\n", " if metadata is None:\n", " metadata = {}\n", @@ -63,24 +59,17 @@ " metadata[\"timestamp\"] = datetime.now().isoformat()\n", "\n", " # Format conversation for storage\n", - " conversation = [\n", - " {\"role\": \"user\", \"content\": message},\n", - " {\"role\": \"assistant\", \"content\": response}\n", - " ]\n", + " conversation = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response}]\n", "\n", " # Store in Mem0\n", - " self.memory.add(\n", - " conversation,\n", - " user_id=user_id,\n", - " metadata=metadata\n", - " )\n", + " self.memory.add(conversation, user_id=user_id, metadata=metadata)\n", "\n", " def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:\n", " \"\"\"Retrieve relevant past interactions.\"\"\"\n", " return self.memory.search(\n", " query=query,\n", " user_id=user_id,\n", - " limit=5 # Adjust based on needs\n", + " limit=5, # Adjust based on needs\n", " )\n", "\n", " def handle_customer_query(self, user_id: str, query: str) -> str:\n", @@ -112,15 +101,12 @@ " model=\"claude-3-5-sonnet-latest\",\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n", " max_tokens=2000,\n", - " temperature=0.1\n", + " temperature=0.1,\n", " )\n", "\n", " # Store interaction\n", " self.store_customer_interaction(\n", - " user_id=user_id,\n", - " message=query,\n", - " response=response,\n", - " metadata={\"type\": \"support_query\"}\n", + " user_id=user_id, message=query, response=response, metadata={\"type\": \"support_query\"}\n", " )\n", "\n", " return response.content[0].text" @@ -203,12 +189,12 @@ " # Get user input\n", " query = input()\n", " print(\"Customer:\", query)\n", - " \n", + "\n", " # Check if user wants to exit\n", - " if query.lower() == 'exit':\n", + " if query.lower() == \"exit\":\n", " print(\"Thank you for using our support service. 
Goodbye!\")\n", " break\n", - " \n", + "\n", " # Handle the query and print the response\n", " response = chatbot.handle_customer_query(user_id, query)\n", " print(\"Support:\", response, \"\\n\\n\")" diff --git a/cookbooks/mem0-autogen.ipynb b/cookbooks/mem0-autogen.ipynb index 43bbb9f2..6e9adb0f 100644 --- a/cookbooks/mem0-autogen.ipynb +++ b/cookbooks/mem0-autogen.ipynb @@ -25,7 +25,8 @@ "source": [ "# Set up ENV Vars\n", "import os\n", - "os.environ['OPENAI_API_KEY'] = \"sk-xxx\"\n" + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"" ] }, { @@ -133,11 +134,9 @@ "assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n", "\n", "# LLM Configuration\n", - "CACHE_SEED = 42 # choose your poison\n", + "CACHE_SEED = 42 # choose your poison\n", "llm_config = {\n", - " \"config_list\": [\n", - " {\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}\n", - " ],\n", + " \"config_list\": [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}],\n", " \"cache_seed\": CACHE_SEED,\n", " \"timeout\": 120,\n", " \"temperature\": 0.0,\n", @@ -348,7 +347,7 @@ "source": [ "# Retrieve the memory\n", "relevant_memories = MEM0_MEMORY_CLIENT.search(user_query, user_id=USER_ID, limit=3)\n", - "relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n", + "relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n", "print(\"Relevant memories:\")\n", "print(relevant_memories_text)\n", "\n", @@ -389,8 +388,8 @@ "# - Enables more context-aware and personalized agent responses.\n", "# - Bridges the gap between human input and AI processing in complex workflows.\n", "\n", - "class Mem0ProxyCoderAgent(UserProxyAgent):\n", "\n", + "class Mem0ProxyCoderAgent(UserProxyAgent):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " self.memory = MEM0_MEMORY_CLIENT\n", @@ -399,15 +398,14 @@ " def initiate_chat(self, assistant, message):\n", " # Retrieve memory for the agent\n", " agent_memories = self.memory.search(message, agent_id=self.agent_id, limit=3)\n", - " agent_memories_txt = '\\n'.join(mem['memory'] for mem in agent_memories)\n", + " agent_memories_txt = \"\\n\".join(mem[\"memory\"] for mem in agent_memories)\n", " prompt = f\"{message}\\n Coding Preferences: \\n{str(agent_memories_txt)}\"\n", " response = super().initiate_chat(assistant, message=prompt)\n", " # Add new memory after processing the message\n", " response_dist = response.__dict__ if not isinstance(response, dict) else response\n", " MEMORY_DATA = [{\"role\": \"user\", \"content\": message}, {\"role\": \"assistant\", \"content\": response_dist}]\n", " self.memory.add(MEMORY_DATA, agent_id=self.agent_id)\n", - " return response\n", - " " + " return response" ] }, { @@ -560,12 +558,12 @@ "from cookbooks.helper.mem0_teachability import Mem0Teachability\n", "\n", "teachability = Mem0Teachability(\n", - " verbosity=2, # for visibility of what's happening\n", - " recall_threshold=0.5,\n", - " reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n", - " agent_id=AGENT_ID,\n", - " memory_client = MEM0_MEMORY_CLIENT,\n", - " )\n", + " verbosity=2, # for visibility of what's happening\n", + " recall_threshold=0.5,\n", + " reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.\n", + " agent_id=AGENT_ID,\n", + " memory_client=MEM0_MEMORY_CLIENT,\n", + ")\n", "teachability.add_to_agent(user_proxy)" ] }, diff --git a/evaluation/evals.py b/evaluation/evals.py index 
28cc82aa..5045acf3 100644 --- a/evaluation/evals.py +++ b/evaluation/evals.py @@ -14,46 +14,47 @@ def process_item(item_data): local_results = defaultdict(list) for item in v: - gt_answer = str(item['answer']) - pred_answer = str(item['response']) - category = str(item['category']) - question = str(item['question']) + gt_answer = str(item["answer"]) + pred_answer = str(item["response"]) + category = str(item["category"]) + question = str(item["question"]) # Skip category 5 - if category == '5': + if category == "5": continue metrics = calculate_metrics(pred_answer, gt_answer) bleu_scores = calculate_bleu_scores(pred_answer, gt_answer) llm_score = evaluate_llm_judge(question, gt_answer, pred_answer) - local_results[k].append({ - "question": question, - "answer": gt_answer, - "response": pred_answer, - "category": category, - "bleu_score": bleu_scores["bleu1"], - "f1_score": metrics["f1"], - "llm_score": llm_score - }) + local_results[k].append( + { + "question": question, + "answer": gt_answer, + "response": pred_answer, + "category": category, + "bleu_score": bleu_scores["bleu1"], + "f1_score": metrics["f1"], + "llm_score": llm_score, + } + ) return local_results def main(): - parser = argparse.ArgumentParser(description='Evaluate RAG results') - parser.add_argument('--input_file', type=str, - default="results/rag_results_500_k1.json", - help='Path to the input dataset file') - parser.add_argument('--output_file', type=str, - default="evaluation_metrics.json", - help='Path to save the evaluation results') - parser.add_argument('--max_workers', type=int, default=10, - help='Maximum number of worker threads') + parser = argparse.ArgumentParser(description="Evaluate RAG results") + parser.add_argument( + "--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file" + ) + parser.add_argument( + "--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results" + ) + parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads") args = parser.parse_args() - with open(args.input_file, 'r') as f: + with open(args.input_file, "r") as f: data = json.load(f) results = defaultdict(list) @@ -61,18 +62,16 @@ def main(): # Use ThreadPoolExecutor with specified workers with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor: - futures = [executor.submit(process_item, item_data) - for item_data in data.items()] + futures = [executor.submit(process_item, item_data) for item_data in data.items()] - for future in tqdm(concurrent.futures.as_completed(futures), - total=len(futures)): + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): local_results = future.result() with results_lock: for k, items in local_results.items(): results[k].extend(items) # Save results to JSON file - with open(args.output_file, 'w') as f: + with open(args.output_file, "w") as f: json.dump(results, f, indent=4) print(f"Results saved to {args.output_file}") diff --git a/evaluation/generate_scores.py b/evaluation/generate_scores.py index c6dfde32..8cb4e848 100644 --- a/evaluation/generate_scores.py +++ b/evaluation/generate_scores.py @@ -3,7 +3,7 @@ import json import pandas as pd # Load the evaluation metrics data -with open('evaluation_metrics.json', 'r') as f: +with open("evaluation_metrics.json", "r") as f: data = json.load(f) # Flatten the data into a list of question items @@ -15,28 +15,20 @@ for key in data: df = pd.DataFrame(all_items) # Convert 
category to numeric type -df['category'] = pd.to_numeric(df['category']) +df["category"] = pd.to_numeric(df["category"]) # Calculate mean scores by category -result = df.groupby('category').agg({ - 'bleu_score': 'mean', - 'f1_score': 'mean', - 'llm_score': 'mean' -}).round(4) +result = df.groupby("category").agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4) # Add count of questions per category -result['count'] = df.groupby('category').size() +result["count"] = df.groupby("category").size() # Print the results print("Mean Scores Per Category:") print(result) # Calculate overall means -overall_means = df.agg({ - 'bleu_score': 'mean', - 'f1_score': 'mean', - 'llm_score': 'mean' -}).round(4) +overall_means = df.agg({"bleu_score": "mean", "f1_score": "mean", "llm_score": "mean"}).round(4) print("\nOverall Mean Scores:") -print(overall_means) \ No newline at end of file +print(overall_means) diff --git a/evaluation/metrics/llm_judge.py b/evaluation/metrics/llm_judge.py index a0a847d7..4d0ec376 100644 --- a/evaluation/metrics/llm_judge.py +++ b/evaluation/metrics/llm_judge.py @@ -33,35 +33,34 @@ Do NOT include both CORRECT and WRONG in your response, or it will break the eva Just return the label CORRECT or WRONG in a json format with the key as "label". """ + def evaluate_llm_judge(question, gold_answer, generated_answer): """Evaluate the generated answer against the gold answer using an LLM judge.""" response = client.chat.completions.create( model="gpt-4o-mini", - messages=[{ - "role": "user", - "content": ACCURACY_PROMPT.format( - question=question, - gold_answer=gold_answer, - generated_answer=generated_answer - ) - }], + messages=[ + { + "role": "user", + "content": ACCURACY_PROMPT.format( + question=question, gold_answer=gold_answer, generated_answer=generated_answer + ), + } + ], response_format={"type": "json_object"}, - temperature=0.0 + temperature=0.0, ) - label = json.loads(response.choices[0].message.content)['label'] + label = json.loads(response.choices[0].message.content)["label"] return 1 if label == "CORRECT" else 0 def main(): """Main function to evaluate RAG results using LLM judge.""" - parser = argparse.ArgumentParser( - description='Evaluate RAG results using LLM judge' - ) + parser = argparse.ArgumentParser(description="Evaluate RAG results using LLM judge") parser.add_argument( - '--input_file', + "--input_file", type=str, default="results/default_run_v4_k30_new_graph.json", - help='Path to the input dataset file' + help="Path to the input dataset file", ) args = parser.parse_args() @@ -78,10 +77,10 @@ def main(): index = 0 for k, v in data.items(): for x in v: - question = x['question'] - gold_answer = x['answer'] - generated_answer = x['response'] - category = x['category'] + question = x["question"] + gold_answer = x["answer"] + generated_answer = x["response"] + category = x["category"] # Skip category 5 if int(category) == 5: @@ -92,13 +91,15 @@ def main(): LLM_JUDGE[category].append(label) # Store the results - RESULTS[index].append({ - "question": question, - "gt_answer": gold_answer, - "response": generated_answer, - "category": category, - "llm_label": label - }) + RESULTS[index].append( + { + "question": question, + "gt_answer": gold_answer, + "response": generated_answer, + "category": category, + "llm_label": label, + } + ) # Save intermediate results with open(output_path, "w") as f: @@ -108,8 +109,7 @@ def main(): print("All categories accuracy:") for cat, results in LLM_JUDGE.items(): if results: # Only print if there are 
results for this category - print(f" Category {cat}: {np.mean(results):.4f} " - f"({sum(results)}/{len(results)})") + print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})") print("------------------------------------------") index += 1 diff --git a/evaluation/metrics/utils.py b/evaluation/metrics/utils.py index b3044c90..a832d5ad 100644 --- a/evaluation/metrics/utils.py +++ b/evaluation/metrics/utils.py @@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py @article{xu2025mem, title={A-mem: Agentic memory for llm agents}, - author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao + author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao and Zhang, Yongfeng}, journal={arXiv preprint arXiv:2502.12110}, year={2025} @@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim # Download required NLTK data try: - nltk.download('punkt', quiet=True) - nltk.download('wordnet', quiet=True) + nltk.download("punkt", quiet=True) + nltk.download("wordnet", quiet=True) except Exception as e: print(f"Error downloading NLTK data: {e}") # Initialize SentenceTransformer model (this will be reused) try: - sentence_model = SentenceTransformer('all-MiniLM-L6-v2') + sentence_model = SentenceTransformer("all-MiniLM-L6-v2") except Exception as e: print(f"Warning: Could not load SentenceTransformer model: {e}") sentence_model = None + def simple_tokenize(text): """Simple tokenization function.""" # Convert to string if not already text = str(text) - return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split() + return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split() + def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]: """Calculate ROUGE scores for prediction against reference.""" - scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) + scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) scores = scorer.score(reference, prediction) return { - 'rouge1_f': scores['rouge1'].fmeasure, - 'rouge2_f': scores['rouge2'].fmeasure, - 'rougeL_f': scores['rougeL'].fmeasure + "rouge1_f": scores["rouge1"].fmeasure, + "rouge2_f": scores["rouge2"].fmeasure, + "rougeL_f": scores["rougeL"].fmeasure, } + def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]: """Calculate BLEU scores with different n-gram settings.""" pred_tokens = nltk.word_tokenize(prediction.lower()) ref_tokens = [nltk.word_tokenize(reference.lower())] - + weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)] smooth = SmoothingFunction().method1 - + scores = {} for n, weights in enumerate(weights_list, start=1): try: @@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]: except Exception as e: print(f"Error calculating BLEU score: {e}") score = 0.0 - scores[f'bleu{n}'] = score - + scores[f"bleu{n}"] = score + return scores + def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]: """Calculate BERTScore for semantic similarity.""" try: - P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False) - return { - 'bert_precision': P.item(), - 'bert_recall': R.item(), - 'bert_f1': F1.item() - } + P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False) + return {"bert_precision": P.item(), "bert_recall": 
R.item(), "bert_f1": F1.item()} except Exception as e: print(f"Error calculating BERTScore: {e}") - return { - 'bert_precision': 0.0, - 'bert_recall': 0.0, - 'bert_f1': 0.0 - } + return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0} + def calculate_meteor_score(prediction: str, reference: str) -> float: """Calculate METEOR score for the prediction.""" @@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float: print(f"Error calculating METEOR score: {e}") return 0.0 + def calculate_sentence_similarity(prediction: str, reference: str) -> float: """Calculate sentence embedding similarity using SentenceBERT.""" if sentence_model is None: @@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float: # Encode sentences embedding1 = sentence_model.encode([prediction], convert_to_tensor=True) embedding2 = sentence_model.encode([reference], convert_to_tensor=True) - + # Calculate cosine similarity similarity = pytorch_cos_sim(embedding1, embedding2).item() return float(similarity) @@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float: print(f"Error calculating sentence similarity: {e}") return 0.0 + def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]: """Calculate comprehensive evaluation metrics for a prediction.""" # Handle empty or None values @@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]: "bleu4": 0.0, "bert_f1": 0.0, "meteor": 0.0, - "sbert_similarity": 0.0 + "sbert_similarity": 0.0, } - + # Convert to strings if they're not already prediction = str(prediction).strip() reference = str(reference).strip() - + # Calculate exact match exact_match = int(prediction.lower() == reference.lower()) - + # Calculate token-based F1 score pred_tokens = set(simple_tokenize(prediction)) ref_tokens = set(simple_tokenize(reference)) common_tokens = pred_tokens & ref_tokens - + if not pred_tokens or not ref_tokens: f1 = 0.0 else: precision = len(common_tokens) / len(pred_tokens) recall = len(common_tokens) / len(ref_tokens) f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 - + # Calculate all scores bleu_scores = calculate_bleu_scores(prediction, reference) - + # Combine all metrics metrics = { "exact_match": exact_match, @@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]: return metrics -def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]: + +def aggregate_metrics( + all_metrics: List[Dict[str, float]], all_categories: List[int] +) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]: """Calculate aggregate statistics for all metrics, split by category.""" if not all_metrics: return {} - + # Initialize aggregates for overall and per-category metrics aggregates = defaultdict(list) category_aggregates = defaultdict(lambda: defaultdict(list)) - + # Collect all values for each metric, both overall and per category for metrics, category in zip(all_metrics, all_categories): for metric_name, value in metrics.items(): aggregates[metric_name].append(value) category_aggregates[category][metric_name].append(value) - + # Calculate statistics for overall metrics - results = { - "overall": {} - } - + results = {"overall": {}} + for metric_name, values in aggregates.items(): results["overall"][metric_name] = { - 'mean': statistics.mean(values), - 'std': 
statistics.stdev(values) if len(values) > 1 else 0.0, - 'median': statistics.median(values), - 'min': min(values), - 'max': max(values), - 'count': len(values) + "mean": statistics.mean(values), + "std": statistics.stdev(values) if len(values) > 1 else 0.0, + "median": statistics.median(values), + "min": min(values), + "max": max(values), + "count": len(values), } - + # Calculate statistics for each category for category in sorted(category_aggregates.keys()): results[f"category_{category}"] = {} for metric_name, values in category_aggregates[category].items(): if values: # Only calculate if we have values for this category results[f"category_{category}"][metric_name] = { - 'mean': statistics.mean(values), - 'std': statistics.stdev(values) if len(values) > 1 else 0.0, - 'median': statistics.median(values), - 'min': min(values), - 'max': max(values), - 'count': len(values) + "mean": statistics.mean(values), + "std": statistics.stdev(values) if len(values) > 1 else 0.0, + "median": statistics.median(values), + "min": min(values), + "max": max(values), + "count": len(values), } - + return results diff --git a/evaluation/prompts.py b/evaluation/prompts.py index e591119d..1d857b4f 100644 --- a/evaluation/prompts.py +++ b/evaluation/prompts.py @@ -144,4 +144,4 @@ ANSWER_PROMPT_ZEP = """ Question: {{question}} Answer: - """ \ No newline at end of file + """ diff --git a/evaluation/run_experiments.py b/evaluation/run_experiments.py index 3424792c..374602e1 100644 --- a/evaluation/run_experiments.py +++ b/evaluation/run_experiments.py @@ -21,23 +21,15 @@ class Experiment: def main(): - parser = argparse.ArgumentParser(description='Run memory experiments') - parser.add_argument('--technique_type', choices=TECHNIQUES, default='mem0', - help='Memory technique to use') - parser.add_argument('--method', choices=METHODS, default='add', - help='Method to use') - parser.add_argument('--chunk_size', type=int, default=1000, - help='Chunk size for processing') - parser.add_argument('--output_folder', type=str, default='results/', - help='Output path for results') - parser.add_argument('--top_k', type=int, default=30, - help='Number of top memories to retrieve') - parser.add_argument('--filter_memories', action='store_true', default=False, - help='Whether to filter memories') - parser.add_argument('--is_graph', action='store_true', default=False, - help='Whether to use graph-based search') - parser.add_argument('--num_chunks', type=int, default=1, - help='Number of chunks to process') + parser = argparse.ArgumentParser(description="Run memory experiments") + parser.add_argument("--technique_type", choices=TECHNIQUES, default="mem0", help="Memory technique to use") + parser.add_argument("--method", choices=METHODS, default="add", help="Method to use") + parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for processing") + parser.add_argument("--output_folder", type=str, default="results/", help="Output path for results") + parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve") + parser.add_argument("--filter_memories", action="store_true", default=False, help="Whether to filter memories") + parser.add_argument("--is_graph", action="store_true", default=False, help="Whether to use graph-based search") + parser.add_argument("--num_chunks", type=int, default=1, help="Number of chunks to process") args = parser.parse_args() @@ -46,33 +38,18 @@ def main(): if args.technique_type == "mem0": if args.method == "add": - memory_manager = MemoryADD( - 
data_path='dataset/locomo10.json', - is_graph=args.is_graph - ) + memory_manager = MemoryADD(data_path="dataset/locomo10.json", is_graph=args.is_graph) memory_manager.process_all_conversations() elif args.method == "search": output_file_path = os.path.join( args.output_folder, - f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json" + f"mem0_results_top_{args.top_k}_filter_{args.filter_memories}_graph_{args.is_graph}.json", ) - memory_searcher = MemorySearch( - output_file_path, - args.top_k, - args.filter_memories, - args.is_graph - ) - memory_searcher.process_data_file('dataset/locomo10.json') + memory_searcher = MemorySearch(output_file_path, args.top_k, args.filter_memories, args.is_graph) + memory_searcher.process_data_file("dataset/locomo10.json") elif args.technique_type == "rag": - output_file_path = os.path.join( - args.output_folder, - f"rag_results_{args.chunk_size}_k{args.num_chunks}.json" - ) - rag_manager = RAGManager( - data_path="dataset/locomo10_rag.json", - chunk_size=args.chunk_size, - k=args.num_chunks - ) + output_file_path = os.path.join(args.output_folder, f"rag_results_{args.chunk_size}_k{args.num_chunks}.json") + rag_manager = RAGManager(data_path="dataset/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks) rag_manager.process_all_conversations(output_file_path) elif args.technique_type == "langmem": output_file_path = os.path.join(args.output_folder, "langmem_results.json") @@ -85,11 +62,7 @@ def main(): elif args.method == "search": output_file_path = os.path.join(args.output_folder, "zep_search_results.json") zep_manager = ZepSearch() - zep_manager.process_data_file( - "dataset/locomo10.json", - "1", - output_file_path - ) + zep_manager.process_data_file("dataset/locomo10.json", "1", output_file_path) elif args.technique_type == "openai": output_file_path = os.path.join(args.output_folder, "openai_results.json") openai_manager = OpenAIPredict() diff --git a/evaluation/src/langmem.py b/evaluation/src/langmem.py index b3dd720f..033343e4 100644 --- a/evaluation/src/langmem.py +++ b/evaluation/src/langmem.py @@ -28,14 +28,12 @@ def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_i speaker_1_user_id=speaker_1_user_id, speaker_1_memories=speaker_1_memories, speaker_2_user_id=speaker_2_user_id, - speaker_2_memories=speaker_2_memories + speaker_2_memories=speaker_2_memories, ) t1 = time.time() response = client.chat.completions.create( - model=os.getenv("MODEL"), - messages=[{"role": "system", "content": prompt}], - temperature=0.0 + model=os.getenv("MODEL"), messages=[{"role": "system", "content": prompt}], temperature=0.0 ) t2 = time.time() return response.choices[0].message.content, t2 - t1 @@ -59,7 +57,9 @@ def prompt(state): class LangMem: - def __init__(self,): + def __init__( + self, + ): self.store = InMemoryStore( index={ "dims": 1536, @@ -80,18 +80,12 @@ class LangMem: ) def add_memory(self, message, config): - return self.agent.invoke( - {"messages": [{"role": "user", "content": message}]}, - config=config - ) + return self.agent.invoke({"messages": [{"role": "user", "content": message}]}, config=config) def search_memory(self, query, config): try: t1 = time.time() - response = self.agent.invoke( - {"messages": [{"role": "user", "content": query}]}, - config=config - ) + response = self.agent.invoke({"messages": [{"role": "user", "content": query}]}, config=config) t2 = time.time() return response["messages"][-1].content, t2 - t1 except Exception as e: @@ -102,7 +96,7 @@ class 
LangMem: class LangMemManager: def __init__(self, dataset_path): self.dataset_path = dataset_path - with open(self.dataset_path, 'r') as f: + with open(self.dataset_path, "r") as f: self.data = json.load(f) def process_all_conversations(self, output_file_path): @@ -123,7 +117,7 @@ class LangMemManager: # Identify speakers for conv in chat_history: - speakers.add(conv['speaker']) + speakers.add(conv["speaker"]) if len(speakers) != 2: raise ValueError(f"Expected 2 speakers, got {len(speakers)}") @@ -134,50 +128,52 @@ class LangMemManager: # Add memories for each message for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False): message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}" - if conv['speaker'] == speaker1: + if conv["speaker"] == speaker1: agent1.add_memory(message, config) - elif conv['speaker'] == speaker2: + elif conv["speaker"] == speaker2: agent2.add_memory(message, config) else: raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}") # Process questions for q in tqdm(questions, desc=f"Processing questions {key}", leave=False): - category = q['category'] + category = q["category"] if int(category) == 5: continue - answer = q['answer'] - question = q['question'] + answer = q["answer"] + question = q["question"] response1, speaker1_memory_time = agent1.search_memory(question, config) response2, speaker2_memory_time = agent2.search_memory(question, config) - generated_answer, response_time = get_answer( - question, speaker1, response1, speaker2, response2 - ) + generated_answer, response_time = get_answer(question, speaker1, response1, speaker2, response2) - result[key].append({ - "question": question, - "answer": answer, - "response1": response1, - "response2": response2, - "category": category, - "speaker1_memory_time": speaker1_memory_time, - "speaker2_memory_time": speaker2_memory_time, - "response_time": response_time, - 'response': generated_answer - }) + result[key].append( + { + "question": question, + "answer": answer, + "response1": response1, + "response2": response2, + "category": category, + "speaker1_memory_time": speaker1_memory_time, + "speaker2_memory_time": speaker2_memory_time, + "response_time": response_time, + "response": generated_answer, + } + ) return result # Use multiprocessing to process conversations in parallel with mp.Pool(processes=10) as pool: - results = list(tqdm( - pool.imap(process_conversation, list(self.data.items())), - total=len(self.data), - desc="Processing conversations" - )) + results = list( + tqdm( + pool.imap(process_conversation, list(self.data.items())), + total=len(self.data), + desc="Processing conversations", + ) + ) # Combine results from all workers for result in results: @@ -185,5 +181,5 @@ class LangMemManager: OUTPUT[key].extend(items) # Save final results - with open(output_file_path, 'w') as f: + with open(output_file_path, "w") as f: json.dump(OUTPUT, f, indent=4) diff --git a/evaluation/src/memzero/add.py b/evaluation/src/memzero/add.py index 03b87c5d..7c8bd12e 100644 --- a/evaluation/src/memzero/add.py +++ b/evaluation/src/memzero/add.py @@ -13,7 +13,7 @@ load_dotenv() # Update custom instructions -custom_instructions =""" +custom_instructions = """ Generate personal memories that follow these guidelines: 1. 
Each memory should be self-contained with complete context, including: @@ -47,7 +47,7 @@ class MemoryADD: self.mem0_client = MemoryClient( api_key=os.getenv("MEM0_API_KEY"), org_id=os.getenv("MEM0_ORGANIZATION_ID"), - project_id=os.getenv("MEM0_PROJECT_ID") + project_id=os.getenv("MEM0_PROJECT_ID"), ) self.mem0_client.update_project(custom_instructions=custom_instructions) @@ -59,15 +59,16 @@ class MemoryADD: self.load_data() def load_data(self): - with open(self.data_path, 'r') as f: + with open(self.data_path, "r") as f: self.data = json.load(f) return self.data def add_memory(self, user_id, message, metadata, retries=3): for attempt in range(retries): try: - _ = self.mem0_client.add(message, user_id=user_id, version="v2", - metadata=metadata, enable_graph=self.is_graph) + _ = self.mem0_client.add( + message, user_id=user_id, version="v2", metadata=metadata, enable_graph=self.is_graph + ) return except Exception as e: if attempt < retries - 1: @@ -78,13 +79,13 @@ class MemoryADD: def add_memories_for_speaker(self, speaker, messages, timestamp, desc): for i in tqdm(range(0, len(messages), self.batch_size), desc=desc): - batch_messages = messages[i:i+self.batch_size] + batch_messages = messages[i : i + self.batch_size] self.add_memory(speaker, batch_messages, metadata={"timestamp": timestamp}) def process_conversation(self, item, idx): - conversation = item['conversation'] - speaker_a = conversation['speaker_a'] - speaker_b = conversation['speaker_b'] + conversation = item["conversation"] + speaker_a = conversation["speaker_a"] + speaker_b = conversation["speaker_b"] speaker_a_user_id = f"{speaker_a}_{idx}" speaker_b_user_id = f"{speaker_b}_{idx}" @@ -94,7 +95,7 @@ class MemoryADD: self.mem0_client.delete_all(user_id=speaker_b_user_id) for key in conversation.keys(): - if key in ['speaker_a', 'speaker_b'] or "date" in key or "timestamp" in key: + if key in ["speaker_a", "speaker_b"] or "date" in key or "timestamp" in key: continue date_time_key = key + "_date_time" @@ -104,10 +105,10 @@ class MemoryADD: messages = [] messages_reverse = [] for chat in chats: - if chat['speaker'] == speaker_a: + if chat["speaker"] == speaker_a: messages.append({"role": "user", "content": f"{speaker_a}: {chat['text']}"}) messages_reverse.append({"role": "assistant", "content": f"{speaker_a}: {chat['text']}"}) - elif chat['speaker'] == speaker_b: + elif chat["speaker"] == speaker_b: messages.append({"role": "assistant", "content": f"{speaker_b}: {chat['text']}"}) messages_reverse.append({"role": "user", "content": f"{speaker_b}: {chat['text']}"}) else: @@ -116,11 +117,11 @@ class MemoryADD: # add memories for the two users on different threads thread_a = threading.Thread( target=self.add_memories_for_speaker, - args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A") + args=(speaker_a_user_id, messages, timestamp, "Adding Memories for Speaker A"), ) thread_b = threading.Thread( target=self.add_memories_for_speaker, - args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B") + args=(speaker_b_user_id, messages_reverse, timestamp, "Adding Memories for Speaker B"), ) thread_a.start() @@ -134,10 +135,7 @@ class MemoryADD: if not self.data: raise ValueError("No data loaded. 
Please set data_path and call load_data() first.") with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [ - executor.submit(self.process_conversation, item, idx) - for idx, item in enumerate(self.data) - ] + futures = [executor.submit(self.process_conversation, item, idx) for idx, item in enumerate(self.data)] for future in futures: - future.result() \ No newline at end of file + future.result() diff --git a/evaluation/src/memzero/search.py b/evaluation/src/memzero/search.py index cf318b15..cf99c2be 100644 --- a/evaluation/src/memzero/search.py +++ b/evaluation/src/memzero/search.py @@ -16,12 +16,11 @@ load_dotenv() class MemorySearch: - - def __init__(self, output_path='results.json', top_k=10, filter_memories=False, is_graph=False): + def __init__(self, output_path="results.json", top_k=10, filter_memories=False, is_graph=False): self.mem0_client = MemoryClient( api_key=os.getenv("MEM0_API_KEY"), org_id=os.getenv("MEM0_ORGANIZATION_ID"), - project_id=os.getenv("MEM0_PROJECT_ID") + project_id=os.getenv("MEM0_PROJECT_ID"), ) self.top_k = top_k self.openai_client = OpenAI() @@ -42,11 +41,18 @@ class MemorySearch: try: if self.is_graph: print("Searching with graph") - memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k, - filter_memories=self.filter_memories, enable_graph=True, output_format='v1.1') + memories = self.mem0_client.search( + query, + user_id=user_id, + top_k=self.top_k, + filter_memories=self.filter_memories, + enable_graph=True, + output_format="v1.1", + ) else: - memories = self.mem0_client.search(query, user_id=user_id, top_k=self.top_k, - filter_memories=self.filter_memories) + memories = self.mem0_client.search( + query, user_id=user_id, top_k=self.top_k, filter_memories=self.filter_memories + ) break except Exception as e: print("Retrying...") @@ -57,64 +63,86 @@ class MemorySearch: end_time = time.time() if not self.is_graph: - semantic_memories = [{'memory': memory['memory'], - 'timestamp': memory['metadata']['timestamp'], - 'score': round(memory['score'], 2)} - for memory in memories] + semantic_memories = [ + { + "memory": memory["memory"], + "timestamp": memory["metadata"]["timestamp"], + "score": round(memory["score"], 2), + } + for memory in memories + ] graph_memories = None else: - semantic_memories = [{'memory': memory['memory'], - 'timestamp': memory['metadata']['timestamp'], - 'score': round(memory['score'], 2)} for memory in memories['results']] - graph_memories = [{"source": relation['source'], "relationship": relation['relationship'], "target": relation['target']} for relation in memories['relations']] + semantic_memories = [ + { + "memory": memory["memory"], + "timestamp": memory["metadata"]["timestamp"], + "score": round(memory["score"], 2), + } + for memory in memories["results"] + ] + graph_memories = [ + {"source": relation["source"], "relationship": relation["relationship"], "target": relation["target"]} + for relation in memories["relations"] + ] return semantic_memories, graph_memories, end_time - start_time def answer_question(self, speaker_1_user_id, speaker_2_user_id, question, answer, category): - speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory(speaker_1_user_id, question) - speaker_2_memories, speaker_2_graph_memories, speaker_2_memory_time = self.search_memory(speaker_2_user_id, question) + speaker_1_memories, speaker_1_graph_memories, speaker_1_memory_time = self.search_memory( + speaker_1_user_id, question + ) + speaker_2_memories, speaker_2_graph_memories, 
speaker_2_memory_time = self.search_memory( + speaker_2_user_id, question + ) - search_1_memory = [f"{item['timestamp']}: {item['memory']}" - for item in speaker_1_memories] - search_2_memory = [f"{item['timestamp']}: {item['memory']}" - for item in speaker_2_memories] + search_1_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_1_memories] + search_2_memory = [f"{item['timestamp']}: {item['memory']}" for item in speaker_2_memories] template = Template(self.ANSWER_PROMPT) answer_prompt = template.render( - speaker_1_user_id=speaker_1_user_id.split('_')[0], - speaker_2_user_id=speaker_2_user_id.split('_')[0], + speaker_1_user_id=speaker_1_user_id.split("_")[0], + speaker_2_user_id=speaker_2_user_id.split("_")[0], speaker_1_memories=json.dumps(search_1_memory, indent=4), speaker_2_memories=json.dumps(search_2_memory, indent=4), speaker_1_graph_memories=json.dumps(speaker_1_graph_memories, indent=4), speaker_2_graph_memories=json.dumps(speaker_2_graph_memories, indent=4), - question=question + question=question, ) t1 = time.time() response = self.openai_client.chat.completions.create( - model=os.getenv("MODEL"), - messages=[ - {"role": "system", "content": answer_prompt} - ], - temperature=0.0 + model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0 ) t2 = time.time() response_time = t2 - t1 - return response.choices[0].message.content, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time + return ( + response.choices[0].message.content, + speaker_1_memories, + speaker_2_memories, + speaker_1_memory_time, + speaker_2_memory_time, + speaker_1_graph_memories, + speaker_2_graph_memories, + response_time, + ) def process_question(self, val, speaker_a_user_id, speaker_b_user_id): - question = val.get('question', '') - answer = val.get('answer', '') - category = val.get('category', -1) - evidence = val.get('evidence', []) - adversarial_answer = val.get('adversarial_answer', '') + question = val.get("question", "") + answer = val.get("answer", "") + category = val.get("category", -1) + evidence = val.get("evidence", []) + adversarial_answer = val.get("adversarial_answer", "") - response, speaker_1_memories, speaker_2_memories, speaker_1_memory_time, speaker_2_memory_time, speaker_1_graph_memories, speaker_2_graph_memories, response_time = self.answer_question( - speaker_a_user_id, - speaker_b_user_id, - question, - answer, - category - ) + ( + response, + speaker_1_memories, + speaker_2_memories, + speaker_1_memory_time, + speaker_2_memory_time, + speaker_1_graph_memories, + speaker_2_graph_memories, + response_time, + ) = self.answer_question(speaker_a_user_id, speaker_b_user_id, question, answer, category) result = { "question": question, @@ -125,67 +153,63 @@ class MemorySearch: "adversarial_answer": adversarial_answer, "speaker_1_memories": speaker_1_memories, "speaker_2_memories": speaker_2_memories, - 'num_speaker_1_memories': len(speaker_1_memories), - 'num_speaker_2_memories': len(speaker_2_memories), - 'speaker_1_memory_time': speaker_1_memory_time, - 'speaker_2_memory_time': speaker_2_memory_time, + "num_speaker_1_memories": len(speaker_1_memories), + "num_speaker_2_memories": len(speaker_2_memories), + "speaker_1_memory_time": speaker_1_memory_time, + "speaker_2_memory_time": speaker_2_memory_time, "speaker_1_graph_memories": speaker_1_graph_memories, "speaker_2_graph_memories": speaker_2_graph_memories, - "response_time": 
response_time + "response_time": response_time, } # Save results after each question is processed - with open(self.output_path, 'w') as f: + with open(self.output_path, "w") as f: json.dump(self.results, f, indent=4) return result def process_data_file(self, file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): - qa = item['qa'] - conversation = item['conversation'] - speaker_a = conversation['speaker_a'] - speaker_b = conversation['speaker_b'] + qa = item["qa"] + conversation = item["conversation"] + speaker_a = conversation["speaker_a"] + speaker_b = conversation["speaker_b"] speaker_a_user_id = f"{speaker_a}_{idx}" speaker_b_user_id = f"{speaker_b}_{idx}" - for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False): - result = self.process_question( - question_item, - speaker_a_user_id, - speaker_b_user_id - ) + for question_item in tqdm( + qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False + ): + result = self.process_question(question_item, speaker_a_user_id, speaker_b_user_id) self.results[idx].append(result) # Save results after each question is processed - with open(self.output_path, 'w') as f: + with open(self.output_path, "w") as f: json.dump(self.results, f, indent=4) # Final save at the end - with open(self.output_path, 'w') as f: + with open(self.output_path, "w") as f: json.dump(self.results, f, indent=4) def process_questions_parallel(self, qa_list, speaker_a_user_id, speaker_b_user_id, max_workers=1): def process_single_question(val): result = self.process_question(val, speaker_a_user_id, speaker_b_user_id) # Save results after each question is processed - with open(self.output_path, 'w') as f: + with open(self.output_path, "w") as f: json.dump(self.results, f, indent=4) return result with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = list(tqdm( - executor.map(process_single_question, qa_list), - total=len(qa_list), - desc="Answering Questions" - )) + results = list( + tqdm(executor.map(process_single_question, qa_list), total=len(qa_list), desc="Answering Questions") + ) # Final save at the end - with open(self.output_path, 'w') as f: + with open(self.output_path, "w") as f: json.dump(self.results, f, indent=4) return results diff --git a/evaluation/src/openai/predict.py b/evaluation/src/openai/predict.py index a2cc1d79..de80626a 100644 --- a/evaluation/src/openai/predict.py +++ b/evaluation/src/openai/predict.py @@ -59,23 +59,19 @@ class OpenAIPredict: self.results = defaultdict(list) def search_memory(self, idx): - - with open(f'memories/{idx}.txt', 'r') as file: + with open(f"memories/{idx}.txt", "r") as file: memories = file.read() return memories, 0 def process_question(self, val, idx): - question = val.get('question', '') - answer = val.get('answer', '') - category = val.get('category', -1) - evidence = val.get('evidence', []) - adversarial_answer = val.get('adversarial_answer', '') + question = val.get("question", "") + answer = val.get("answer", "") + category = val.get("category", -1) + evidence = val.get("evidence", []) + adversarial_answer = val.get("adversarial_answer", "") - response, search_memory_time, response_time, context = self.answer_question( - idx, - question - ) + response, search_memory_time, response_time, context = self.answer_question(idx, question) result = { "question": question, @@ -86,7 +82,7 @@ class 
OpenAIPredict: "adversarial_answer": adversarial_answer, "search_memory_time": search_memory_time, "response_time": response_time, - "context": context + "context": context, } return result @@ -95,43 +91,35 @@ class OpenAIPredict: memories, search_memory_time = self.search_memory(idx) template = Template(ANSWER_PROMPT) - answer_prompt = template.render( - memories=memories, - question=question - ) + answer_prompt = template.render(memories=memories, question=question) t1 = time.time() response = self.openai_client.chat.completions.create( - model=os.getenv("MODEL"), - messages=[ - {"role": "system", "content": answer_prompt} - ], - temperature=0.0 + model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0 ) t2 = time.time() response_time = t2 - t1 return response.choices[0].message.content, search_memory_time, response_time, memories def process_data_file(self, file_path, output_file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): - qa = item['qa'] + qa = item["qa"] - for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False): - result = self.process_question( - question_item, - idx - ) + for question_item in tqdm( + qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False + ): + result = self.process_question(question_item, idx) self.results[idx].append(result) # Save results after each question is processed - with open(output_file_path, 'w') as f: + with open(output_file_path, "w") as f: json.dump(self.results, f, indent=4) # Final save at the end - with open(output_file_path, 'w') as f: + with open(output_file_path, "w") as f: json.dump(self.results, f, indent=4) @@ -141,4 +129,3 @@ if __name__ == "__main__": args = parser.parse_args() openai_predict = OpenAIPredict() openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path) - diff --git a/evaluation/src/rag.py b/evaluation/src/rag.py index 4edaf5dd..b3ce5a39 100644 --- a/evaluation/src/rag.py +++ b/evaluation/src/rag.py @@ -33,10 +33,7 @@ class RAGManager: def generate_response(self, question, context): template = Template(PROMPT) - prompt = template.render( - CONTEXT=context, - QUESTION=question - ) + prompt = template.render(CONTEXT=context, QUESTION=question) max_retries = 3 retries = 0 @@ -47,19 +44,21 @@ class RAGManager: response = self.client.chat.completions.create( model=self.model, messages=[ - {"role": "system", - "content": "You are a helpful assistant that can answer " - "questions based on the provided context." - "If the question involves timing, use the conversation date for reference." - "Provide the shortest possible answer." - "Use words directly from the conversation when possible." - "Avoid using subjects in your answer."}, - {"role": "user", "content": prompt} + { + "role": "system", + "content": "You are a helpful assistant that can answer " + "questions based on the provided context." + "If the question involves timing, use the conversation date for reference." + "Provide the shortest possible answer." + "Use words directly from the conversation when possible." 
+ "Avoid using subjects in your answer.", + }, + {"role": "user", "content": prompt}, ], - temperature=0 + temperature=0, ) t2 = time.time() - return response.choices[0].message.content.strip(), t2-t1 + return response.choices[0].message.content.strip(), t2 - t1 except Exception as e: retries += 1 if retries > max_retries: @@ -69,21 +68,16 @@ class RAGManager: def clean_chat_history(self, chat_history): cleaned_chat_history = "" for c in chat_history: - cleaned_chat_history += (f"{c['timestamp']} | {c['speaker']}: " - f"{c['text']}\n") + cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: " f"{c['text']}\n" return cleaned_chat_history def calculate_embedding(self, document): - response = self.client.embeddings.create( - model=os.getenv("EMBEDDING_MODEL"), - input=document - ) + response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document) return response.data[0].embedding def calculate_similarity(self, embedding1, embedding2): - return np.dot(embedding1, embedding2) / ( - np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) + return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) def search(self, query, chunks, embeddings, k=1): """ @@ -101,10 +95,7 @@ class RAGManager: """ t1 = time.time() query_embedding = self.calculate_embedding(query) - similarities = [ - self.calculate_similarity(query_embedding, embedding) - for embedding in embeddings - ] + similarities = [self.calculate_similarity(query_embedding, embedding) for embedding in embeddings] # Get indices of top-k most similar chunks if k == 1: @@ -118,7 +109,7 @@ class RAGManager: combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices]) t2 = time.time() - return combined_chunks, t2-t1 + return combined_chunks, t2 - t1 def create_chunks(self, chat_history, chunk_size=500): """ @@ -139,7 +130,7 @@ class RAGManager: # Split into chunks based on token count for i in range(0, len(tokens), chunk_size): - chunk_tokens = tokens[i:i+chunk_size] + chunk_tokens = tokens[i : i + chunk_size] chunk = encoding.decode(chunk_tokens) chunks.append(chunk) @@ -159,13 +150,9 @@ class RAGManager: chat_history = value["conversation"] questions = value["question"] - chunks, embeddings = self.create_chunks( - chat_history, self.chunk_size - ) + chunks, embeddings = self.create_chunks(chat_history, self.chunk_size) - for item in tqdm( - questions, desc="Answering questions", leave=False - ): + for item in tqdm(questions, desc="Answering questions", leave=False): question = item["question"] answer = item.get("answer", "") category = item["category"] @@ -174,22 +161,20 @@ class RAGManager: context = chunks[0] search_time = 0 else: - context, search_time = self.search( - question, chunks, embeddings, k=self.k - ) - response, response_time = self.generate_response( - question, context - ) + context, search_time = self.search(question, chunks, embeddings, k=self.k) + response, response_time = self.generate_response(question, context) - FINAL_RESULTS[key].append({ - "question": question, - "answer": answer, - "category": category, - "context": context, - "response": response, - "search_time": search_time, - "response_time": response_time, - }) + FINAL_RESULTS[key].append( + { + "question": question, + "answer": answer, + "category": category, + "context": context, + "response": response, + "search_time": search_time, + "response_time": response_time, + } + ) with open(output_file_path, "w+") as f: json.dump(FINAL_RESULTS, f, indent=4) diff --git a/evaluation/src/utils.py 
b/evaluation/src/utils.py index b8f5ecf5..7ee8e493 100644 --- a/evaluation/src/utils.py +++ b/evaluation/src/utils.py @@ -1,12 +1,3 @@ -TECHNIQUES = [ - "mem0", - "rag", - "langmem", - "zep", - "openai" -] +TECHNIQUES = ["mem0", "rag", "langmem", "zep", "openai"] -METHODS = [ - "add", - "search" -] +METHODS = ["add", "search"] diff --git a/evaluation/src/zep/add.py b/evaluation/src/zep/add.py index 1e05c11c..43198b09 100644 --- a/evaluation/src/zep/add.py +++ b/evaluation/src/zep/add.py @@ -19,12 +19,12 @@ class ZepAdd: self.load_data() def load_data(self): - with open(self.data_path, 'r') as f: + with open(self.data_path, "r") as f: self.data = json.load(f) return self.data def process_conversation(self, run_id, item, idx): - conversation = item['conversation'] + conversation = item["conversation"] user_id = f"run_id_{run_id}_experiment_user_{idx}" session_id = f"run_id_{run_id}_experiment_session_{idx}" @@ -41,7 +41,7 @@ class ZepAdd: print("Starting to add memories... for user", user_id) for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"): - if key in ['speaker_a', 'speaker_b'] or "date" in key: + if key in ["speaker_a", "speaker_b"] or "date" in key: continue date_time_key = key + "_date_time" @@ -51,11 +51,13 @@ class ZepAdd: for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False): self.zep_client.memory.add( session_id=session_id, - messages=[Message( - role=chat['speaker'], - role_type="user", - content=f"{timestamp}: {chat['text']}", - )] + messages=[ + Message( + role=chat["speaker"], + role_type="user", + content=f"{timestamp}: {chat['text']}", + ) + ], ) def process_all_conversations(self, run_id): @@ -71,4 +73,4 @@ if __name__ == "__main__": parser.add_argument("--run_id", type=str, required=True) args = parser.parse_args() zep_add = ZepAdd(data_path="../../dataset/locomo10.json") - zep_add.process_all_conversations(args.run_id) \ No newline at end of file + zep_add.process_all_conversations(args.run_id) diff --git a/evaluation/src/zep/search.py b/evaluation/src/zep/search.py index c14c2eac..cfb1df44 100644 --- a/evaluation/src/zep/search.py +++ b/evaluation/src/zep/search.py @@ -42,9 +42,9 @@ class ZepSearch: return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}" def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str: - facts = [f' - {edge.fact} ({self.format_edge_date_range(edge)})' for edge in edges] - entities = [f' - {node.name}: {node.summary}' for node in nodes] - return TEMPLATE.format(facts='\n'.join(facts), entities='\n'.join(entities)) + facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges] + entities = [f" - {node.name}: {node.summary}" for node in nodes] + return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities)) def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1): start_time = time.time() @@ -52,8 +52,14 @@ class ZepSearch: while retries < max_retries: try: user_id = f"run_id_{run_id}_experiment_user_{idx}" - edges_results = (self.zep_client.graph.search(user_id=user_id, reranker='cross_encoder', query=query, scope='edges', limit=20)).edges - node_results = (self.zep_client.graph.search(user_id=user_id, reranker='rrf', query=query, scope='nodes', limit=20)).nodes + edges_results = ( + self.zep_client.graph.search( + user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20 + ) + ).edges + node_results = ( + 
self.zep_client.graph.search(user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20) + ).nodes context = self.compose_search_context(edges_results, node_results) break except Exception as e: @@ -68,17 +74,13 @@ class ZepSearch: return context, end_time - start_time def process_question(self, run_id, val, idx): - question = val.get('question', '') - answer = val.get('answer', '') - category = val.get('category', -1) - evidence = val.get('evidence', []) - adversarial_answer = val.get('adversarial_answer', '') + question = val.get("question", "") + answer = val.get("answer", "") + category = val.get("category", -1) + evidence = val.get("evidence", []) + adversarial_answer = val.get("adversarial_answer", "") - response, search_memory_time, response_time, context = self.answer_question( - run_id, - idx, - question - ) + response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question) result = { "question": question, @@ -89,7 +91,7 @@ class ZepSearch: "adversarial_answer": adversarial_answer, "search_memory_time": search_memory_time, "response_time": response_time, - "context": context + "context": context, } return result @@ -98,44 +100,35 @@ class ZepSearch: context, search_memory_time = self.search_memory(run_id, idx, question) template = Template(ANSWER_PROMPT_ZEP) - answer_prompt = template.render( - memories=context, - question=question - ) + answer_prompt = template.render(memories=context, question=question) t1 = time.time() response = self.openai_client.chat.completions.create( - model=os.getenv("MODEL"), - messages=[ - {"role": "system", "content": answer_prompt} - ], - temperature=0.0 + model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0 ) t2 = time.time() response_time = t2 - t1 return response.choices[0].message.content, search_memory_time, response_time, context def process_data_file(self, file_path, run_id, output_file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"): - qa = item['qa'] + qa = item["qa"] - for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False): - result = self.process_question( - run_id, - question_item, - idx - ) + for question_item in tqdm( + qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False + ): + result = self.process_question(run_id, question_item, idx) self.results[idx].append(result) # Save results after each question is processed - with open(output_file_path, 'w') as f: + with open(output_file_path, "w") as f: json.dump(self.results, f, indent=4) # Final save at the end - with open(output_file_path, 'w') as f: + with open(output_file_path, "w") as f: json.dump(self.results, f, indent=4) diff --git a/examples/graph-db-demo/memgraph-example.ipynb b/examples/graph-db-demo/memgraph-example.ipynb index b559b6e2..bd302dcf 100644 --- a/examples/graph-db-demo/memgraph-example.ipynb +++ b/examples/graph-db-demo/memgraph-example.ipynb @@ -56,9 +56,7 @@ "\n", "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = (\n", - " \"\"\n", - ")" + "os.environ[\"OPENAI_API_KEY\"] = \"\"" ] }, { @@ -149,7 +147,7 @@ " \"role\": \"assistant\",\n", " \"content\": \"Got it! 
I'll avoid thriller recommendations and suggest sci-fi movies in the future.\",\n", " },\n", - "]\n" + "]" ] }, { @@ -166,9 +164,7 @@ "outputs": [], "source": [ "# Store inferred memories (default behavior)\n", - "result = m.add(\n", - " messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"}\n", - ")" + "result = m.add(messages, user_id=\"alice\", metadata={\"category\": \"movie_recommendations\"})" ] }, { diff --git a/examples/misc/fitness_checker.py b/examples/misc/fitness_checker.py index 29d2a3e8..d7f879b3 100644 --- a/examples/misc/fitness_checker.py +++ b/examples/misc/fitness_checker.py @@ -20,19 +20,19 @@ agent = Agent( name="Fitness Agent", model=OpenAIChat(id="gpt-4o"), description="You are a helpful fitness assistant who remembers past logs and gives personalized suggestions for Anish's training and diet.", - markdown=True + markdown=True, ) # Store user preferences as memory def store_user_preferences(conversation: list, user_id: str = USER_ID): """Store user preferences from conversation history""" - memory_client.add(conversation, user_id=user_id, output_format='v1.1') + memory_client.add(conversation, user_id=user_id, output_format="v1.1") # Memory-aware assistant function def fitness_coach(user_input: str, user_id: str = USER_ID): - memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query + memories = memory_client.search(user_input, user_id=user_id) # Search relevant memories bases on user query memory_context = "\n".join(f"- {m['memory']}" for m in memories) prompt = f"""You are a fitness assistant who helps Anish with his training, recovery, and diet. You have long-term memory of his health, routines, preferences, and past conversations. @@ -48,113 +48,66 @@ User query: memory_client.add(f"User: {user_input}\nAssistant: {response.content}", user_id=user_id) return response.content + # -------------------------------------------------- # Store user preferences and memories messages = [ { "role": "user", - "content": "Hi, I’m Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle." + "content": "Hi, I’m Anish. I'm 26 years old, 5'10\", and weigh 72kg. I started working out 6 months ago with the goal of building lean muscle.", }, { "role": "assistant", - "content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago." + "content": "Got it — you're 26, 5'10\", 72kg, and on a lean muscle journey. Started gym 6 months ago.", }, { "role": "user", - "content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday." + "content": "I follow a push-pull-legs routine and train 5 times a week. My rest days are Wednesday and Sunday.", }, { "role": "assistant", - "content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays." 
+ "content": "Understood — push-pull-legs split, training 5x/week with rest on Wednesdays and Sundays.", }, + {"role": "user", "content": "After push days, I usually eat high-protein and moderate-carb meals to recover."}, + {"role": "assistant", "content": "Noted — high-protein, moderate-carb meals after push workouts."}, + {"role": "user", "content": "For pull days, I take whey protein and eat a banana after training."}, + {"role": "assistant", "content": "Logged — whey protein and banana post pull workouts."}, + {"role": "user", "content": "On leg days, I make sure to have complex carbs like rice or oats."}, + {"role": "assistant", "content": "Noted — complex carbs like rice and oats are part of your leg day meals."}, { "role": "user", - "content": "After push days, I usually eat high-protein and moderate-carb meals to recover." - }, - { - "role": "assistant", - "content": "Noted — high-protein, moderate-carb meals after push workouts." + "content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery.", }, + {"role": "assistant", "content": "I'll remember turmeric milk and magnesium as part of your leg day recovery."}, { "role": "user", - "content": "For pull days, I take whey protein and eat a banana after training." + "content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after.", }, { "role": "assistant", - "content": "Logged — whey protein and banana post pull workouts." + "content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward.", }, + {"role": "user", "content": "I prefer light dinners post-workout like tofu, soup, and vegetables."}, + {"role": "assistant", "content": "Got it — light dinners post-workout: tofu, soup, and veggies."}, { "role": "user", - "content": "On leg days, I make sure to have complex carbs like rice or oats." - }, - { - "role": "assistant", - "content": "Noted — complex carbs like rice and oats are part of your leg day meals." + "content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey.", }, + {"role": "assistant", "content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey."}, { "role": "user", - "content": "I often feel sore after leg days, so I use turmeric milk and magnesium to help with recovery." + "content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days.", }, { "role": "assistant", - "content": "I'll remember turmeric milk and magnesium as part of your leg day recovery." - }, - { - "role": "user", - "content": "Last push day, I did 3x8 bench press at 60kg, 4x12 overhead press, and dips. Felt fatigued after." - }, - { - "role": "assistant", - "content": "Push day logged — 60kg bench, overhead press, dips. You felt fatigued afterward." - }, - { - "role": "user", - "content": "I prefer light dinners post-workout like tofu, soup, and vegetables." - }, - { - "role": "assistant", - "content": "Got it — light dinners post-workout: tofu, soup, and veggies." - }, - { - "role": "user", - "content": "I have mild lactose intolerance, so I avoid dairy. I use almond milk or lactose-free whey." - }, - { - "role": "assistant", - "content": "Understood — avoiding regular dairy, using almond milk and lactose-free whey." - }, - { - "role": "user", - "content": "I get occasional knee pain, so I avoid deep squats and do more hamstring curls and glute bridges on leg days." 
- }, - { - "role": "assistant", - "content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges." - }, - { - "role": "user", - "content": "I track sleep and notice poor performance when I sleep less than 6 hours." - }, - { - "role": "assistant", - "content": "Logged — performance drops when you get under 6 hours of sleep." - }, - { - "role": "user", - "content": "I take magnesium supplements to help with muscle recovery and sleep quality." - }, - { - "role": "assistant", - "content": "Remembered — magnesium helps you with recovery and sleep." - }, - { - "role": "user", - "content": "I avoid caffeine after 4 PM because it affects my sleep." - }, - { - "role": "assistant", - "content": "Got it — you avoid caffeine post-4 PM to protect your sleep." + "content": "Noted — due to knee discomfort, you substitute deep squats with curls and glute bridges.", }, + {"role": "user", "content": "I track sleep and notice poor performance when I sleep less than 6 hours."}, + {"role": "assistant", "content": "Logged — performance drops when you get under 6 hours of sleep."}, + {"role": "user", "content": "I take magnesium supplements to help with muscle recovery and sleep quality."}, + {"role": "assistant", "content": "Remembered — magnesium helps you with recovery and sleep."}, + {"role": "user", "content": "I avoid caffeine after 4 PM because it affects my sleep."}, + {"role": "assistant", "content": "Got it — you avoid caffeine post-4 PM to protect your sleep."}, ] store_user_preferences(messages) diff --git a/examples/misc/healthcare_assistant_google_adk.py b/examples/misc/healthcare_assistant_google_adk.py index fa58ecca..0665ce89 100644 --- a/examples/misc/healthcare_assistant_google_adk.py +++ b/examples/misc/healthcare_assistant_google_adk.py @@ -1,9 +1,11 @@ import asyncio import warnings + from google.adk.agents import Agent -from google.adk.sessions import InMemorySessionService from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService from google.genai import types + from mem0 import MemoryClient warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -19,14 +21,14 @@ def save_patient_info(information: str) -> dict: print(f"Storing patient information: {information[:30]}...") # Get user_id from session state or use default - user_id = getattr(save_patient_info, 'user_id', 'default_user') + user_id = getattr(save_patient_info, "user_id", "default_user") # Store in Mem0 - response = mem0_client.add( + mem0_client.add( [{"role": "user", "content": information}], user_id=user_id, run_id="healthcare_session", - metadata={"type": "patient_information"} + metadata={"type": "patient_information"}, ) return {"status": "success", "message": "Information saved"} @@ -37,7 +39,7 @@ def retrieve_patient_info(query: str) -> str: print(f"Searching for patient information: {query}") # Get user_id from session state or use default - user_id = getattr(retrieve_patient_info, 'user_id', 'default_user') + user_id = getattr(retrieve_patient_info, "user_id", "default_user") # Search Mem0 results = mem0_client.search( @@ -45,7 +47,7 @@ def retrieve_patient_info(query: str) -> str: user_id=user_id, run_id="healthcare_session", limit=5, - threshold=0.7 # Higher threshold for more relevant results + threshold=0.7, # Higher threshold for more relevant results ) if not results: @@ -65,7 +67,7 @@ def schedule_appointment(date: str, time: str, reason: str) -> dict: "status": "success", "appointment_id": appointment_id, "confirmation": 
f"Appointment scheduled for {date} at {time} for {reason}", - "message": "Please arrive 15 minutes early to complete paperwork." + "message": "Please arrive 15 minutes early to complete paperwork.", } @@ -89,7 +91,7 @@ IMPORTANT GUIDELINES: - For serious symptoms, always recommend consulting a healthcare professional. - Keep all patient information confidential. """, - tools=[save_patient_info, retrieve_patient_info, schedule_appointment] + tools=[save_patient_info, retrieve_patient_info, schedule_appointment], ) # Set Up Session and Runner @@ -101,18 +103,10 @@ USER_ID = "Alex" SESSION_ID = "session_001" # Create a session -session = session_service.create_session( - app_name=APP_NAME, - user_id=USER_ID, - session_id=SESSION_ID -) +session = session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=SESSION_ID) # Create the runner -runner = Runner( - agent=healthcare_agent, - app_name=APP_NAME, - session_service=session_service -) +runner = Runner(agent=healthcare_agent, app_name=APP_NAME, session_service=session_service) # Interact with the Healthcare Assistant @@ -121,21 +115,14 @@ async def call_agent_async(query, runner, user_id, session_id): print(f"\n>>> Patient: {query}") # Format the user's message - content = types.Content( - role='user', - parts=[types.Part(text=query)] - ) + content = types.Content(role="user", parts=[types.Part(text=query)]) # Set user_id for tools to access save_patient_info.user_id = user_id retrieve_patient_info.user_id = user_id # Run the agent - async for event in runner.run_async( - user_id=user_id, - session_id=session_id, - new_message=content - ): + async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=content): if event.is_final_response(): if event.content and event.content.parts: response = event.content.parts[0].text @@ -152,7 +139,7 @@ async def run_conversation(): "Hi, I'm Alex. I've been having headaches for the past week, and I have a penicillin allergy.", runner=runner, user_id=USER_ID, - session_id=SESSION_ID + session_id=SESSION_ID, ) # Request for health information @@ -160,7 +147,7 @@ async def run_conversation(): "Can you tell me more about what might be causing my headaches?", runner=runner, user_id=USER_ID, - session_id=SESSION_ID + session_id=SESSION_ID, ) # Schedule an appointment @@ -168,15 +155,12 @@ async def run_conversation(): "I think I should see a doctor. Can you help me schedule an appointment for next Monday at 2pm?", runner=runner, user_id=USER_ID, - session_id=SESSION_ID + session_id=SESSION_ID, ) # Test memory - should remember patient name, symptoms, and allergy await call_agent_async( - "What medications should I avoid for my headaches?", - runner=runner, - user_id=USER_ID, - session_id=SESSION_ID + "What medications should I avoid for my headaches?", runner=runner, user_id=USER_ID, session_id=SESSION_ID ) @@ -191,37 +175,28 @@ async def interactive_mode(): session_id = f"session_{hash(patient_id) % 1000:03d}" # Create session for this user - session = session_service.create_session( - app_name=APP_NAME, - user_id=patient_id, - session_id=session_id - ) + session_service.create_session(app_name=APP_NAME, user_id=patient_id, session_id=session_id) print(f"\nStarting conversation with patient ID: {patient_id}") print("Type your message and press Enter.") while True: user_input = input("\n>>> Patient: ").strip() - if user_input.lower() in ['exit', 'quit', 'bye']: + if user_input.lower() in ["exit", "quit", "bye"]: print("Ending conversation. 
Thank you!") break - await call_agent_async( - user_input, - runner=runner, - user_id=patient_id, - session_id=session_id - ) + await call_agent_async(user_input, runner=runner, user_id=patient_id, session_id=session_id) # Main execution if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Healthcare Assistant with Memory') - parser.add_argument('--demo', action='store_true', help='Run the demo conversation') - parser.add_argument('--interactive', action='store_true', help='Run in interactive mode') - parser.add_argument('--patient-id', type=str, default=USER_ID, help='Patient ID for the conversation') + parser = argparse.ArgumentParser(description="Healthcare Assistant with Memory") + parser.add_argument("--demo", action="store_true", help="Run the demo conversation") + parser.add_argument("--interactive", action="store_true", help="Run in interactive mode") + parser.add_argument("--patient-id", type=str, default=USER_ID, help="Patient ID for the conversation") args = parser.parse_args() if args.demo: @@ -231,5 +206,3 @@ if __name__ == "__main__": else: # Default to demo mode if no arguments provided asyncio.run(run_conversation()) - - diff --git a/examples/misc/movie_recommendation_grok3.py b/examples/misc/movie_recommendation_grok3.py index 80463c58..bb4edbc4 100644 --- a/examples/misc/movie_recommendation_grok3.py +++ b/examples/misc/movie_recommendation_grok3.py @@ -16,26 +16,21 @@ from mem0 import Memory # Configure Mem0 with Grok 3 and Qdrant config = { - "vector_store": { - "provider": "qdrant", - "config": { - "embedding_model_dims": 384 - } - }, + "vector_store": {"provider": "qdrant", "config": {"embedding_model_dims": 384}}, "llm": { "provider": "xai", "config": { "model": "grok-3-beta", "temperature": 0.1, "max_tokens": 2000, - } + }, }, "embedder": { "provider": "huggingface", "config": { "model": "all-MiniLM-L6-v2" # open embedding model - } - } + }, + }, } # Instantiate memory layer @@ -57,20 +52,14 @@ def recommend_movie_with_memory(user_id: str, user_query: str): prompt += f"\nPreviously, the user mentioned: {past_memories}" # Generate movie recommendation using Grok 3 - response = grok_client.chat.completions.create( - model="grok-3-beta", - messages=[ - {"role": "user", "content": prompt} - ] - ) + response = grok_client.chat.completions.create(model="grok-3-beta", messages=[{"role": "user", "content": prompt}]) recommendation = response.choices[0].message.content # Store conversation in memory memory.add( - [{"role": "user", "content": user_query}, - {"role": "assistant", "content": recommendation}], + [{"role": "user", "content": user_query}, {"role": "assistant", "content": recommendation}], user_id=user_id, - metadata={"category": "movie"} + metadata={"category": "movie"}, ) return recommendation @@ -81,10 +70,11 @@ if __name__ == "__main__": user_id = "arshi" recommend_movie_with_memory(user_id, "I'm looking for a movie to watch tonight. Any suggestions?") # OUTPUT: You have watched Intersteller last weekend and you don't like horror movies, maybe you can watch "Purple Hearts" today. - recommend_movie_with_memory(user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians.") + recommend_movie_with_memory( + user_id, "Can we skip the tearjerkers? I really enjoyed Notting Hill and Crazy Rich Asians." + ) # OUTPUT: Got it — no sad endings! You might enjoy "The Proposal" or "Love, Rosie". They’re both light-hearted romcoms with happy vibes. 
recommend_movie_with_memory(user_id, "Any light-hearted movie I can watch after work today?") # OUTPUT: Since you liked Crazy Rich Asians and The Proposal, how about "The Intern" or "Isn’t It Romantic"? Both are upbeat, funny, and perfect for relaxing. recommend_movie_with_memory(user_id, "I’ve already watched The Intern. Something new maybe?") # OUTPUT: No problem! Try "Your Place or Mine" - romcoms that match your taste and are tear-free! - diff --git a/examples/misc/personal_assistant_agno.py b/examples/misc/personal_assistant_agno.py index 474115d1..22674898 100644 --- a/examples/misc/personal_assistant_agno.py +++ b/examples/misc/personal_assistant_agno.py @@ -23,8 +23,8 @@ agent = Agent( name="Personal Agent", model=OpenAIChat(id="gpt-4o"), description="You are a helpful personal agent that helps me with day to day activities." - "You can process both text and images.", - markdown=True + "You can process both text and images.", + markdown=True, ) @@ -35,24 +35,16 @@ def chat_user(user_input: str = None, user_id: str = "user_123", image_path: str base64_image = base64.b64encode(image_file.read()).decode("utf-8") # First: the text message - text_msg = { - "role": "user", - "content": user_input - } + text_msg = {"role": "user", "content": user_input} # Second: the image message image_msg = { "role": "user", - "content": { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } + "content": {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, } # Send both as separate message objects - client.add([text_msg, image_msg], user_id=user_id, output_format='v1.1') + client.add([text_msg, image_msg], user_id=user_id, output_format="v1.1") print("✅ Image uploaded and stored in memory.") if user_input: @@ -92,10 +84,13 @@ print(chat_user("When is my test?", user_id=user_id)) # OUTPUT: Your pilot's test is on your birthday, which is in five days. You're turning 25! # Good luck with your preparations, and remember to take some time to relax amidst the studying. -print(chat_user("This is the picture of what I brought with me in the trip to Bahamas", - image_path="travel_items.jpeg", # this will be added to Mem0 memory - user_id=user_id)) -print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", - user_id=user_id)) +print( + chat_user( + "This is the picture of what I brought with me in the trip to Bahamas", + image_path="travel_items.jpeg", # this will be added to Mem0 memory + user_id=user_id, + ) +) +print(chat_user("hey can you quickly tell me if brought my sunglasses to my trip, not able to find", user_id=user_id)) # OUTPUT: Yes, you did bring your sunglasses on your trip to the Bahamas along with your laptop, face masks and other items.. # Since you can't find them now, perhaps check the pockets of jackets you wore or in your luggage compartments. 
diff --git a/examples/misc/study_buddy.py b/examples/misc/study_buddy.py index e02ac07f..796cab8f 100644 --- a/examples/misc/study_buddy.py +++ b/examples/misc/study_buddy.py @@ -7,6 +7,7 @@ In order to run this file, you need to set up your Mem0 API at Mem0 platform and export OPENAI_API_KEY="your_openai_api_key" export MEM0_API_KEY="your_mem0_api_key" """ + import asyncio from agents import Agent, Runner @@ -23,25 +24,19 @@ study_agent = Agent( - Identify topics the user has struggled with (e.g., "I'm confused", "this is hard") - Help with spaced repetition by suggesting topics to revisit based on last review time - Personalize answers using stored memories -- Summarize PDFs or notes the user uploads""") +- Summarize PDFs or notes the user uploads""", +) # Upload and store PDF to Mem0 def upload_pdf(pdf_url: str, user_id: str): - pdf_message = { - "role": "user", - "content": { - "type": "pdf_url", - "pdf_url": {"url": pdf_url} - } - } + pdf_message = {"role": "user", "content": {"type": "pdf_url", "pdf_url": {"url": pdf_url}}} client.add([pdf_message], user_id=user_id) print("✅ PDF uploaded and processed into memory.") # Main interaction loop with your personal study buddy async def study_buddy(user_id: str, topic: str, user_input: str): - memories = client.search(f"{topic}", user_id=user_id) memory_context = "n".join(f"- {m['memory']}" for m in memories) @@ -56,9 +51,11 @@ Now respond to the user's new question or comment: result = await Runner.run(study_agent, prompt) response = result.final_output - client.add([ - {"role": "user", "content": f'''Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}'''} - ], user_id=user_id, metadata={"topic": topic}) + client.add( + [{"role": "user", "content": f"""Topic: {topic}nUser: {user_input}nnStudy Assistant: {response}"""}], + user_id=user_id, + metadata={"topic": topic}, + ) return response @@ -78,7 +75,12 @@ async def main(): # Demonstrate spaced repetition prompting topic = "Momentum Conservation" - print(await study_buddy(user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?")) + print( + await study_buddy( + user_id, topic, "I think we covered this last week. Is it time to review momentum conservation again?" + ) + ) + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/misc/voice_assistant_elevenlabs.py b/examples/misc/voice_assistant_elevenlabs.py index 279e5b6b..51bcf618 100644 --- a/examples/misc/voice_assistant_elevenlabs.py +++ b/examples/misc/voice_assistant_elevenlabs.py @@ -57,7 +57,7 @@ def initialize_memory(): }, { "role": "user", - "content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know." + "content": "I prefer brief and concise responses without unnecessary explanations. I get frustrated when assistants are too wordy or repeat information I already know.", }, { "role": "assistant", @@ -65,7 +65,7 @@ def initialize_memory(): }, { "role": "user", - "content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. I find it helps me focus and be more productive." + "content": "I like to listen to jazz music when I'm working, especially artists like Miles Davis and John Coltrane. 
I find it helps me focus and be more productive.",
         },
         {
             "role": "assistant",
@@ -73,7 +73,7 @@
         },
         {
             "role": "user",
-            "content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time."
+            "content": "I usually wake up at 7 AM and prefer reminders for meetings 30 minutes in advance. My most productive hours are between 9 AM and noon, so I try to schedule important tasks during that time.",
         },
         {
             "role": "assistant",
@@ -81,7 +81,7 @@
         },
         {
             "role": "user",
-            "content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants."
+            "content": "My favorite color is navy blue, and I prefer dark mode in all my apps. I'm allergic to peanuts, so please remind me to check ingredients when I ask about recipes or restaurants.",
         },
         {
             "role": "assistant",
@@ -89,7 +89,7 @@
         },
         {
             "role": "user",
-            "content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months."
+            "content": "My partner's name is Jamie, and we have a golden retriever named Max who is 3 years old. My parents live in Chicago, and I try to visit them once every two months.",
         },
         {
             "role": "assistant",
@@ -135,11 +135,11 @@ def record_audio(filename="input.wav", record_seconds=5):
     stream.close()
     p.terminate()
 
-    with wave.open(filename, 'wb') as wf:
+    with wave.open(filename, "wb") as wf:
         wf.setnchannels(channels)
         wf.setsampwidth(p.get_sample_size(fmt))
         wf.setframerate(rate)
-        wf.writeframes(b''.join(frames))
+        wf.writeframes(b"".join(frames))
 
 
 # ------------------ STT USING WHISPER ------------------
@@ -147,10 +147,7 @@ def transcribe_whisper(audio_path):
     print("🔎 Transcribing with Whisper...")
     try:
         with open(audio_path, "rb") as audio_file:
-            transcript = openai_client.audio.transcriptions.create(
-                model="whisper-1",
-                file=audio_file
-            )
+            transcript = openai_client.audio.transcriptions.create(model="whisper-1", file=audio_file)
             print(f"🗣️ You said: {transcript.text}")
             return transcript.text
     except Exception as e:
@@ -165,9 +162,7 @@ def get_agent_response(user_input):
 
     try:
         task = Task(
-            description=f"Respond to: {user_input}",
-            expected_output="A short and relevant reply.",
-            agent=voice_agent
+            description=f"Respond to: {user_input}", expected_output="A short and relevant reply.", agent=voice_agent
         )
         crew = Crew(
             agents=[voice_agent],
@@ -175,22 +170,19 @@
             process=Process.sequential,
             verbose=True,
             memory=True,
-            memory_config={
-                "provider": "mem0",
-                "config": {"user_id": USER_ID}
-            }
+            memory_config={"provider": "mem0", "config": {"user_id": USER_ID}},
         )
         result = crew.kickoff()
 
         # Extract the text response from the complex result object
-        if hasattr(result, 'raw'):
+        if hasattr(result, "raw"):
             return result.raw
-        elif isinstance(result, dict) and 'raw' in result:
-            return result['raw']
-        elif isinstance(result, dict) and 'tasks_output' in result:
-            outputs = result['tasks_output']
+        elif isinstance(result, dict) and "raw" in result:
+            return result["raw"]
+        elif isinstance(result, dict) and "tasks_output" in result:
+            outputs = result["tasks_output"]
             if outputs and isinstance(outputs, list) and len(outputs) > 0:
-                return outputs[0].get('raw', str(result))
+                return outputs[0].get("raw", str(result))
 
         # Fallback to string representation if we can't extract the raw response
         return str(result)
@@ -204,10 +196,7 @@ def get_agent_response(user_input):
 def speak_response(text):
     print(f"🤖 Agent: {text}")
     audio = tts_client.text_to_speech.convert(
-        text=text,
-        voice_id="JBFqnCBsd6RMkjVDRZzb",
-        model_id="eleven_multilingual_v2",
-        output_format="mp3_44100_128"
+        text=text, voice_id="JBFqnCBsd6RMkjVDRZzb", model_id="eleven_multilingual_v2", output_format="mp3_44100_128"
     )
     play(audio)
 
@@ -220,7 +209,7 @@ def run_voice_agent():
             record_audio(tmp_audio.name)
             try:
                 user_text = transcribe_whisper(tmp_audio.name)
-                if user_text.lower() in ['exit', 'quit', 'stop']:
+                if user_text.lower() in ["exit", "quit", "stop"]:
                     print("👋 Exiting.")
                     break
                 response = get_agent_response(user_text)
diff --git a/mem0/client/main.py b/mem0/client/main.py
index 891e64a0..8f69fd04 100644
--- a/mem0/client/main.py
+++ b/mem0/client/main.py
@@ -95,10 +95,7 @@ class MemoryClient:
             self.client = client
             # Ensure the client has the correct base_url and headers
             self.client.base_url = httpx.URL(self.host)
-            self.client.headers.update({
-                "Authorization": f"Token {self.api_key}",
-                "Mem0-User-ID": self.user_id
-            })
+            self.client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id})
         else:
             self.client = httpx.Client(
                 base_url=self.host,
@@ -237,7 +234,9 @@ class MemoryClient:
         response.raise_for_status()
         if "metadata" in kwargs:
             del kwargs["metadata"]
-        capture_client_event("client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"})
+        capture_client_event(
+            "client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "sync"}
+        )
         return response.json()
 
     @api_error_handler
@@ -357,10 +356,7 @@ class MemoryClient:
         else:
             entities = self.users()
         # Filter entities based on provided IDs using list comprehension
-        to_delete = [
-            {"type": entity["type"], "name": entity["name"]}
-            for entity in entities["results"]
-        ]
+        to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]]
 
         params = self._prepare_params()
 
@@ -373,7 +369,9 @@ class MemoryClient:
             response.raise_for_status()
 
         capture_client_event(
-            "client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"}
+            "client.delete_users",
+            self,
+            {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "sync"},
         )
         return {
             "message": "Entity deleted successfully."
@@ -454,7 +452,9 @@ class MemoryClient: """ response = self.client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)}) response.raise_for_status() - capture_client_event("client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"}) + capture_client_event( + "client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "sync"} + ) return response.json() @api_error_handler @@ -527,7 +527,11 @@ class MemoryClient: ) payload = self._prepare_params( - {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria} + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + } ) response = self.client.patch( f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", @@ -537,7 +541,12 @@ class MemoryClient: capture_client_event( "client.update_project", self, - {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "sync"}, + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "sync_type": "sync", + }, ) return response.json() @@ -750,10 +759,7 @@ class AsyncMemoryClient: self.async_client = client # Ensure the client has the correct base_url and headers self.async_client.base_url = httpx.URL(self.host) - self.async_client.headers.update({ - "Authorization": f"Token {self.api_key}", - "Mem0-User-ID": self.user_id - }) + self.async_client.headers.update({"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}) else: self.async_client = httpx.AsyncClient( base_url=self.host, @@ -768,7 +774,11 @@ class AsyncMemoryClient: """Validate the API key by making a test request.""" try: params = self._prepare_params() - response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params) + response = requests.get( + f"{self.host}/v1/ping/", + headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, + params=params, + ) data = response.json() response.raise_for_status() @@ -973,10 +983,7 @@ class AsyncMemoryClient: else: entities = await self.users() # Filter entities based on provided IDs using list comprehension - to_delete = [ - {"type": entity["type"], "name": entity["name"]} - for entity in entities["results"] - ] + to_delete = [{"type": entity["type"], "name": entity["name"]} for entity in entities["results"]] params = self._prepare_params() @@ -988,7 +995,11 @@ class AsyncMemoryClient: response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params) response.raise_for_status() - capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"}) + capture_client_event( + "client.delete_users", + self, + {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"}, + ) return { "message": "Entity deleted successfully." 
if (user_id or agent_id or app_id or run_id) @@ -1091,8 +1102,10 @@ class AsyncMemoryClient: @api_error_handler async def update_project( - self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None, - retrieval_criteria: Optional[List[Dict[str, Any]]] = None + self, + custom_instructions: Optional[str] = None, + custom_categories: Optional[List[str]] = None, + retrieval_criteria: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") @@ -1103,7 +1116,11 @@ class AsyncMemoryClient: ) payload = self._prepare_params( - {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria} + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + } ) response = await self.async_client.patch( f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", @@ -1113,7 +1130,12 @@ class AsyncMemoryClient: capture_client_event( "client.update_project", self, - {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"}, + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "sync_type": "async", + }, ) return response.json() @@ -1174,4 +1196,3 @@ class AsyncMemoryClient: response.raise_for_status() capture_client_event("client.feedback", self, data, {"sync_type": "async"}) return response.json() - diff --git a/mem0/configs/vector_stores/opensearch.py b/mem0/configs/vector_stores/opensearch.py index 8f158277..1afe6cf3 100644 --- a/mem0/configs/vector_stores/opensearch.py +++ b/mem0/configs/vector_stores/opensearch.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union, Type +from typing import Any, Dict, Optional, Type, Union from pydantic import BaseModel, Field, model_validator @@ -7,33 +7,17 @@ class OpenSearchConfig(BaseModel): collection_name: str = Field("mem0", description="Name of the index") host: str = Field("localhost", description="OpenSearch host") port: int = Field(9200, description="OpenSearch port") - user: Optional[str] = Field( - None, description="Username for authentication" - ) - password: Optional[str] = Field( - None, description="Password for authentication" - ) - api_key: Optional[str] = Field( - None, description="API key for authentication (if applicable)" - ) - embedding_model_dims: int = Field( - 1536, description="Dimension of the embedding vector" - ) - verify_certs: bool = Field( - False, description="Verify SSL certificates (default False for OpenSearch)" - ) - use_ssl: bool = Field( - False, description="Use SSL for connection (default False for OpenSearch)" - ) - http_auth: Optional[object] = Field( - None, description="HTTP authentication method / AWS SigV4" - ) + user: Optional[str] = Field(None, description="Username for authentication") + password: Optional[str] = Field(None, description="Password for authentication") + api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)") + embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector") + verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)") + use_ssl: bool = Field(False, description="Use SSL for 
connection (default False for OpenSearch)") + http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4") connection_class: Optional[Union[str, Type]] = Field( "RequestsHttpConnection", description="Connection class for OpenSearch" ) - pool_maxsize: int = Field( - 20, description="Maximum number of connections in the pool" - ) + pool_maxsize: int = Field(20, description="Maximum number of connections in the pool") @model_validator(mode="before") @classmethod @@ -41,7 +25,7 @@ class OpenSearchConfig(BaseModel): # Check if host is provided if not values.get("host"): raise ValueError("Host must be provided for OpenSearch") - + return values @model_validator(mode="before") @@ -52,7 +36,6 @@ class OpenSearchConfig(BaseModel): extra_fields = input_fields - allowed_fields if extra_fields: raise ValueError( - f"Extra fields not allowed: {', '.join(extra_fields)}. " - f"Allowed fields: {', '.join(allowed_fields)}" + f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}" ) return values diff --git a/mem0/embeddings/aws_bedrock.py b/mem0/embeddings/aws_bedrock.py index 10116511..807764c1 100644 --- a/mem0/embeddings/aws_bedrock.py +++ b/mem0/embeddings/aws_bedrock.py @@ -23,12 +23,12 @@ class AWSBedrockEmbedding(EmbeddingBase): super().__init__(config) self.config.model = self.config.model or "amazon.titan-embed-text-v1" - + # Get AWS config from environment variables or use defaults aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "") aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "") aws_region = os.environ.get("AWS_REGION", "us-west-2") - + # Check if AWS config is provided in the config if hasattr(self.config, "aws_access_key_id"): aws_access_key = self.config.aws_access_key_id @@ -36,7 +36,7 @@ class AWSBedrockEmbedding(EmbeddingBase): aws_secret_key = self.config.aws_secret_access_key if hasattr(self.config, "aws_region"): aws_region = self.config.aws_region - + self.client = boto3.client( "bedrock-runtime", region_name=aws_region, diff --git a/mem0/embeddings/huggingface.py b/mem0/embeddings/huggingface.py index 31fcafd7..934c69ad 100644 --- a/mem0/embeddings/huggingface.py +++ b/mem0/embeddings/huggingface.py @@ -11,6 +11,7 @@ logging.getLogger("transformers").setLevel(logging.WARNING) logging.getLogger("sentence_transformers").setLevel(logging.WARNING) logging.getLogger("huggingface_hub").setLevel(logging.WARNING) + class HuggingFaceEmbedding(EmbeddingBase): def __init__(self, config: Optional[BaseEmbedderConfig] = None): super().__init__(config) diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py index bbfcca8a..e8e0fd45 100644 --- a/mem0/graphs/configs.py +++ b/mem0/graphs/configs.py @@ -22,7 +22,8 @@ class Neo4jConfig(BaseModel): if not url or not username or not password: raise ValueError("Please provide 'url', 'username' and 'password'.") return values - + + class MemgraphConfig(BaseModel): url: Optional[str] = Field(None, description="Host address for the graph database") username: Optional[str] = Field(None, description="Username for the graph database") diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index adf03762..dde66f30 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -20,18 +20,19 @@ def extract_provider(model: str) -> str: return provider raise ValueError(f"Unknown provider in model: {model}") + class AWSBedrockLLM(LLMBase): def __init__(self, config: Optional[BaseLlmConfig] = None): super().__init__(config) if not 
self.config.model: self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0" - + # Get AWS config from environment variables or use defaults aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "") aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "") aws_region = os.environ.get("AWS_REGION", "us-west-2") - + # Check if AWS config is provided in the config if hasattr(self.config, "aws_access_key_id"): aws_access_key = self.config.aws_access_key_id @@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase): aws_secret_key = self.config.aws_secret_access_key if hasattr(self.config, "aws_region"): aws_region = self.config.aws_region - + self.client = boto3.client( "bedrock-runtime", region_name=aws_region, aws_access_key_id=aws_access_key if aws_access_key else None, aws_secret_access_key=aws_secret_key if aws_secret_key else None, ) - + self.model_kwargs = { "temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, @@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase): input_body = { "inputText": prompt, "textGenerationConfig": { - "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000, + "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] + or self.model_kwargs["max_tokens"] + or 5000, "topP": self.model_kwargs["top_p"] or 0.9, "temperature": self.model_kwargs["temperature"] or 0.1, }, @@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase): body = json.dumps(input_body) if provider == "anthropic" or provider == "deepseek": - input_body = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": prompt}] - } - ], + "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}], "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000, "temperature": self.model_kwargs["temperature"] or 0.1, "top_p": self.model_kwargs["top_p"] or 0.9, "anthropic_version": "bedrock-2023-05-31", } - + body = json.dumps(input_body) - response = self.client.invoke_model( body=body, @@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase): modelId=self.config.model, accept="application/json", contentType="application/json", - ) + ) return self._parse_response(response, tools) diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 3744724d..585b73b2 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory def _build_filters_and_metadata( - *, # Enforce keyword-only arguments + *, # Enforce keyword-only arguments user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None, - actor_id: Optional[str] = None, # For query-time filtering + actor_id: Optional[str] = None, # For query-time filtering input_metadata: Optional[Dict[str, Any]] = None, input_filters: Optional[Dict[str, Any]] = None, ) -> tuple[Dict[str, Any], Dict[str, Any]]: """ Constructs metadata for storage and filters for querying based on session and actor identifiers. - + This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts: @@ -78,10 +78,10 @@ def _build_filters_and_metadata( - effective_query_filters (Dict[str, Any]): Filters for querying memories, scoped to the determined session and potentially a resolved actor. 
""" - + base_metadata_template = deepcopy(input_metadata) if input_metadata else {} effective_query_filters = deepcopy(input_filters) if input_filters else {} - + # ---------- resolve session id (mandatory) ---------- session_key, session_val = None, None if user_id: @@ -90,20 +90,20 @@ def _build_filters_and_metadata( session_key, session_val = "agent_id", agent_id elif run_id: session_key, session_val = "run_id", run_id - + if session_key is None: raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.") - + base_metadata_template[session_key] = session_val effective_query_filters[session_key] = session_val - + # ---------- optional actor filter ---------- resolved_actor_id = actor_id or effective_query_filters.get("actor_id") if resolved_actor_id: effective_query_filters["actor_id"] = resolved_actor_id - + return base_metadata_template, effective_query_filters - + setup_config() logger = logging.getLogger(__name__) @@ -189,7 +189,7 @@ class Memory(MemoryBase): ): """ Create a new memory. - + Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required. Args: @@ -208,7 +208,7 @@ class Memory(MemoryBase): creating procedural memories (typically requires 'agent_id'). Otherwise, memories are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories. prompt (str, optional): Prompt to use for the memory creation. Defaults to None. - + Returns: dict: A dictionary containing the result of the memory addition operation, typically @@ -216,14 +216,14 @@ class Memory(MemoryBase): and potentially "relations" if graph store is enabled. Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}` """ - + processed_metadata, effective_filters = _build_filters_and_metadata( user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata, ) - + if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: raise ValueError( f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories." @@ -231,10 +231,10 @@ class Memory(MemoryBase): if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - + elif isinstance(messages, dict): messages = [messages] - + elif not isinstance(messages, list): raise ValueError("messages must be str, dict, or list[dict]") @@ -255,7 +255,7 @@ class Memory(MemoryBase): vector_store_result = future1.result() graph_result = future2.result() - + if self.api_version == "v1.0": warnings.warn( "The current add API output format is deprecated. 
" @@ -277,21 +277,21 @@ class Memory(MemoryBase): def _add_to_vector_store(self, messages, metadata, filters, infer): if not infer: returned_memories = [] - for message_dict in messages: - if not isinstance(message_dict, dict) or \ - message_dict.get("role") is None or \ - message_dict.get("content") is None: + for message_dict in messages: + if ( + not isinstance(message_dict, dict) + or message_dict.get("role") is None + or message_dict.get("content") is None + ): logger.warning(f"Skipping invalid message format: {message_dict}") continue if message_dict["role"] == "system": - continue + continue - per_msg_meta = deepcopy(metadata) per_msg_meta["role"] = message_dict["role"] - actor_name = message_dict.get("name") if actor_name: per_msg_meta["actor_id"] = actor_name @@ -311,8 +311,8 @@ class Memory(MemoryBase): ) return returned_memories - parsed_messages = parse_messages(messages) - + parsed_messages = parse_messages(messages) + if self.config.custom_fact_extraction_prompt: system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" @@ -336,7 +336,7 @@ class Memory(MemoryBase): retrieved_old_memory = [] new_message_embeddings = {} - for new_mem in new_retrieved_facts: + for new_mem in new_retrieved_facts: messages_embeddings = self.embedding_model.embed(new_mem, "add") new_message_embeddings[new_mem] = messages_embeddings existing_memories = self.vector_store.search( @@ -347,7 +347,7 @@ class Memory(MemoryBase): ) for mem in existing_memories: retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]}) - + unique_data = {} for item in retrieved_old_memory: unique_data[item["id"]] = item @@ -389,7 +389,7 @@ class Memory(MemoryBase): if not action_text: logging.info("Skipping memory entry because of empty `text` field.") continue - + event_type = resp.get("event") if event_type == "ADD": memory_id = self._create_memory( @@ -405,16 +405,23 @@ class Memory(MemoryBase): existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata), ) - returned_memories.append({ - "id": temp_uuid_mapping[resp.get("id")], "memory": action_text, - "event": event_type, "previous_memory": resp.get("old_memory"), - }) + returned_memories.append( + { + "id": temp_uuid_mapping[resp.get("id")], + "memory": action_text, + "event": event_type, + "previous_memory": resp.get("old_memory"), + } + ) elif event_type == "DELETE": self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]) - returned_memories.append({ - "id": temp_uuid_mapping[resp.get("id")], "memory": action_text, - "event": event_type, - }) + returned_memories.append( + { + "id": temp_uuid_mapping[resp.get("id")], + "memory": action_text, + "event": event_type, + } + ) elif event_type == "NONE": logging.info("NOOP for Memory.") except Exception as e: @@ -462,11 +469,8 @@ class Memory(MemoryBase): "actor_id", "role", ] - - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} result_item = MemoryItem( id=memory.id, @@ -479,18 +483,16 @@ class Memory(MemoryBase): for key in promoted_payload_keys: if key in memory.payload: result_item[key] = memory.payload[key] - - additional_metadata = { - k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys} if additional_metadata: result_item["metadata"] = 
additional_metadata - + return result_item def get_all( self, - *, + *, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None, @@ -505,7 +507,7 @@ class Memory(MemoryBase): agent_id (str, optional): agent id run_id (str, optional): run id filters (dict, optional): Additional custom key-value filters to apply to the search. - These are merged with the ID-based scoping filters. For example, + These are merged with the ID-based scoping filters. For example, `filters={"actor_id": "some_user"}`. limit (int, optional): The maximum number of memories to return. Defaults to 100. @@ -515,21 +517,16 @@ class Memory(MemoryBase): it might return a direct list (see deprecation warning). Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` """ - + _, effective_filters = _build_filters_and_metadata( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - input_filters=filters + user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters ) - + if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.") capture_event( - "mem0.get_all", - self, - {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"} + "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"} ) with concurrent.futures.ThreadPoolExecutor() as executor: @@ -542,9 +539,9 @@ class Memory(MemoryBase): [future_memories, future_graph_entities] if future_graph_entities else [future_memories] ) - all_memories_result = future_memories.result() + all_memories_result = future_memories.result() graph_entities_result = future_graph_entities.result() if future_graph_entities else None - + if self.enable_graph: return {"results": all_memories_result, "relations": graph_entities_result} @@ -556,26 +553,27 @@ class Memory(MemoryBase): category=DeprecationWarning, stacklevel=2, ) - return all_memories_result + return all_memories_result else: return {"results": all_memories_result} def _get_all_from_vector_store(self, filters, limit): memories_result = self.vector_store.list(filters=filters, limit=limit) - actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result + actual_memories = ( + memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result + ) promoted_payload_keys = [ - "user_id", "agent_id", "run_id", + "user_id", + "agent_id", + "run_id", "actor_id", "role", ] - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} formatted_memories = [] - for mem in actual_memories: + for mem in actual_memories: memory_item_dict = MemoryItem( id=mem.id, memory=mem.payload["data"], @@ -587,15 +585,13 @@ class Memory(MemoryBase): for key in promoted_payload_keys: if key in mem.payload: memory_item_dict[key] = mem.payload[key] - - additional_metadata = { - k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - + formatted_memories.append(memory_item_dict) - + return formatted_memories def search( @@ -624,12 +620,9 @@ class Memory(MemoryBase): Example for v1.1+: 
`{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}` """ _, effective_filters = _build_filters_and_metadata( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - input_filters=filters + user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters ) - + if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.") @@ -651,7 +644,7 @@ class Memory(MemoryBase): original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - + if self.enable_graph: return {"results": original_memories, "relations": graph_entities} @@ -678,11 +671,8 @@ class Memory(MemoryBase): "actor_id", "role", ] - - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} original_memories = [] for mem in memories: @@ -693,18 +683,16 @@ class Memory(MemoryBase): created_at=mem.payload.get("created_at"), updated_at=mem.payload.get("updated_at"), score=mem.score, - ).model_dump() + ).model_dump() for key in promoted_payload_keys: if key in mem.payload: memory_item_dict[key] = mem.payload[key] - - additional_metadata = { - k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - + original_memories.append(memory_item_dict) return original_memories @@ -738,7 +726,7 @@ class Memory(MemoryBase): self._delete_memory(memory_id) return {"message": "Memory deleted successfully!"} - def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None): + def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None): """ Delete all memories. @@ -860,11 +848,11 @@ class Memory(MemoryBase): except Exception: logger.error(f"Error getting memory with ID {memory_id} during update.") raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") - + prev_value = existing_memory.payload.get("data") new_metadata = deepcopy(metadata) if metadata is not None else {} - + new_metadata["data"] = data new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["created_at"] = existing_memory.payload.get("created_at") @@ -875,7 +863,7 @@ class Memory(MemoryBase): if "agent_id" in existing_memory.payload: new_metadata["agent_id"] = existing_memory.payload["agent_id"] if "run_id" in existing_memory.payload: - new_metadata["run_id"] = existing_memory.payload["run_id"] + new_metadata["run_id"] = existing_memory.payload["run_id"] if "actor_id" in existing_memory.payload: new_metadata["actor_id"] = existing_memory.payload["actor_id"] if "role" in existing_memory.payload: @@ -885,14 +873,14 @@ class Memory(MemoryBase): embeddings = existing_embeddings[data] else: embeddings = self.embedding_model.embed(data, "update") - + self.vector_store.update( vector_id=memory_id, vector=embeddings, payload=new_metadata, ) logger.info(f"Updating memory with ID {memory_id=} with {data=}") - + self.db.add_history( memory_id, prev_value, @@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase): dict: A dictionary containing the result of the memory addition operation. 
""" processed_metadata, effective_filters = _build_filters_and_metadata( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - input_metadata=metadata + user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata ) - + if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: raise ValueError( f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories." @@ -1050,15 +1035,17 @@ class AsyncMemory(MemoryBase): if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - + elif isinstance(messages, dict): messages = [messages] - + elif not isinstance(messages, list): raise ValueError("messages must be str, dict, or list[dict]") if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: - results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm) + results = await self._create_procedural_memory( + messages, metadata=processed_metadata, prompt=prompt, llm=llm + ) return results if self.config.llm.config.get("enable_vision"): @@ -1066,7 +1053,9 @@ class AsyncMemory(MemoryBase): else: messages = parse_vision_messages(messages) - vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)) + vector_store_task = asyncio.create_task( + self._add_to_vector_store(messages, processed_metadata, effective_filters, infer) + ) graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters)) vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task) @@ -1090,8 +1079,8 @@ class AsyncMemory(MemoryBase): return {"results": vector_store_result} async def _add_to_vector_store( - self, - messages: list, + self, + messages: list, metadata: dict, filters: dict, infer: bool, @@ -1099,9 +1088,11 @@ class AsyncMemory(MemoryBase): if not infer: returned_memories = [] for message_dict in messages: - if not isinstance(message_dict, dict) or \ - message_dict.get("role") is None or \ - message_dict.get("content") is None: + if ( + not isinstance(message_dict, dict) + or message_dict.get("role") is None + or message_dict.get("content") is None + ): logger.warning(f"Skipping invalid message format (async): {message_dict}") continue @@ -1110,20 +1101,24 @@ class AsyncMemory(MemoryBase): per_msg_meta = deepcopy(metadata) per_msg_meta["role"] = message_dict["role"] - + actor_name = message_dict.get("name") if actor_name: per_msg_meta["actor_id"] = actor_name - + msg_content = message_dict["content"] msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add") mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta) - - returned_memories.append({ - "id": mem_id, "memory": msg_content, "event": "ADD", - "actor_id": actor_name if actor_name else None, - "role": message_dict["role"], - }) + + returned_memories.append( + { + "id": mem_id, + "memory": msg_content, + "event": "ADD", + "actor_id": actor_name if actor_name else None, + "role": message_dict["role"], + } + ) return returned_memories parsed_messages = parse_messages(messages) @@ -1142,17 +1137,21 @@ class AsyncMemory(MemoryBase): response = remove_code_blocks(response) new_retrieved_facts = json.loads(response)["facts"] except Exception as e: - logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = [] + logging.error(f"Error in new_retrieved_facts: {e}") + new_retrieved_facts = [] retrieved_old_memory = [] new_message_embeddings = {} - + 
async def process_fact_for_search(new_mem_content): embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add") new_message_embeddings[new_mem_content] = embeddings existing_mems = await asyncio.to_thread( - self.vector_store.search, query=new_mem_content, vectors=embeddings, - limit=5, filters=filters, # 'filters' is query_filters_for_inference + self.vector_store.search, + query=new_mem_content, + vectors=embeddings, + limit=5, + filters=filters, # 'filters' is query_filters_for_inference ) return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems] @@ -1160,9 +1159,10 @@ class AsyncMemory(MemoryBase): search_results_list = await asyncio.gather(*search_tasks) for result_group in search_results_list: retrieved_old_memory.extend(result_group) - + unique_data = {} - for item in retrieved_old_memory: unique_data[item["id"]] = item + for item in retrieved_old_memory: + unique_data[item["id"]] = item retrieved_old_memory = list(unique_data.values()) logging.info(f"Total existing memories: {len(retrieved_old_memory)}") temp_uuid_mapping = {} @@ -1180,35 +1180,45 @@ class AsyncMemory(MemoryBase): response_format={"type": "json_object"}, ) except Exception as e: - logging.error(f"Error in new memory actions response: {e}"); response = "" - + logging.error(f"Error in new memory actions response: {e}") + response = "" + try: response = remove_code_blocks(response) new_memories_with_actions = json.loads(response) except Exception as e: - logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {} + logging.error(f"Invalid JSON response: {e}") + new_memories_with_actions = {} - returned_memories = [] + returned_memories = [] try: memory_tasks = [] for resp in new_memories_with_actions.get("memory", []): logging.info(resp) try: action_text = resp.get("text") - if not action_text: continue + if not action_text: + continue event_type = resp.get("event") if event_type == "ADD": - task = asyncio.create_task(self._create_memory( - data=action_text, existing_embeddings=new_message_embeddings, - metadata=deepcopy(metadata) - )) + task = asyncio.create_task( + self._create_memory( + data=action_text, + existing_embeddings=new_message_embeddings, + metadata=deepcopy(metadata), + ) + ) memory_tasks.append((task, resp, "ADD", None)) elif event_type == "UPDATE": - task = asyncio.create_task(self._update_memory( - memory_id=temp_uuid_mapping[resp["id"]], data=action_text, - existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata) - )) + task = asyncio.create_task( + self._update_memory( + memory_id=temp_uuid_mapping[resp["id"]], + data=action_text, + existing_embeddings=new_message_embeddings, + metadata=deepcopy(metadata), + ) + ) memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]])) elif event_type == "DELETE": task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])) @@ -1217,31 +1227,30 @@ class AsyncMemory(MemoryBase): logging.info("NOOP for Memory (async).") except Exception as e: logging.error(f"Error processing memory action (async): {resp}, Error: {e}") - + for task, resp, event_type, mem_id in memory_tasks: try: result_id = await task if event_type == "ADD": - returned_memories.append({ - "id": result_id, "memory": resp.get("text"), "event": event_type - }) + returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type}) elif event_type == "UPDATE": - returned_memories.append({ - "id": mem_id, "memory": resp.get("text"), - "event": 
event_type, "previous_memory": resp.get("old_memory") - }) + returned_memories.append( + { + "id": mem_id, + "memory": resp.get("text"), + "event": event_type, + "previous_memory": resp.get("old_memory"), + } + ) elif event_type == "DELETE": - returned_memories.append({ - "id": mem_id, "memory": resp.get("text"), "event": event_type - }) + returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type}) except Exception as e: logging.error(f"Error awaiting memory task (async): {e}") except Exception as e: logging.error(f"Error in memory processing loop (async): {e}") - + capture_event( - "mem0.add", self, - {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"} + "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"} ) return returned_memories @@ -1272,17 +1281,14 @@ class AsyncMemory(MemoryBase): return None promoted_payload_keys = [ - "user_id", - "agent_id", - "run_id", + "user_id", + "agent_id", + "run_id", "actor_id", "role", ] - - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} result_item = MemoryItem( id=memory.id, @@ -1295,18 +1301,16 @@ class AsyncMemory(MemoryBase): for key in promoted_payload_keys: if key in memory.payload: result_item[key] = memory.payload[key] - - additional_metadata = { - k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys} if additional_metadata: result_item["metadata"] = additional_metadata - + return result_item async def get_all( self, - *, + *, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None, @@ -1314,41 +1318,36 @@ class AsyncMemory(MemoryBase): limit: int = 100, ): """ - List all memories. + List all memories. - Args: - user_id (str, optional): user id - agent_id (str, optional): agent id - run_id (str, optional): run id - filters (dict, optional): Additional custom key-value filters to apply to the search. - These are merged with the ID-based scoping filters. For example, - `filters={"actor_id": "some_user"}`. - limit (int, optional): The maximum number of memories to return. Defaults to 100. + Args: + user_id (str, optional): user id + agent_id (str, optional): agent id + run_id (str, optional): run id + filters (dict, optional): Additional custom key-value filters to apply to the search. + These are merged with the ID-based scoping filters. For example, + `filters={"actor_id": "some_user"}`. + limit (int, optional): The maximum number of memories to return. Defaults to 100. - Returns: - dict: A dictionary containing a list of memories under the "results" key, - and potentially "relations" if graph store is enabled. For API v1.0, - it might return a direct list (see deprecation warning). - Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` + Returns: + dict: A dictionary containing a list of memories under the "results" key, + and potentially "relations" if graph store is enabled. For API v1.0, + it might return a direct list (see deprecation warning). 
+ Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` """ - + _, effective_filters = _build_filters_and_metadata( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - input_filters=filters + user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters ) if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError( - "When 'conversation_id' is not provided (classic mode), " - "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all." - ) + raise ValueError( + "When 'conversation_id' is not provided (classic mode), " + "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all." + ) capture_event( - "mem0.get_all", - self, - {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"} + "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"} ) with concurrent.futures.ThreadPoolExecutor() as executor: @@ -1361,9 +1360,9 @@ class AsyncMemory(MemoryBase): [future_memories, future_graph_entities] if future_graph_entities else [future_memories] ) - all_memories_result = future_memories.result() + all_memories_result = future_memories.result() graph_entities_result = future_graph_entities.result() if future_graph_entities else None - + if self.enable_graph: return {"results": all_memories_result, "relations": graph_entities_result} @@ -1381,20 +1380,21 @@ class AsyncMemory(MemoryBase): async def _get_all_from_vector_store(self, filters, limit): memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit) - actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result + actual_memories = ( + memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result + ) promoted_payload_keys = [ - "user_id", "agent_id", "run_id", + "user_id", + "agent_id", + "run_id", "actor_id", "role", ] - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} formatted_memories = [] - for mem in actual_memories: + for mem in actual_memories: memory_item_dict = MemoryItem( id=mem.id, memory=mem.payload["data"], @@ -1406,15 +1406,13 @@ class AsyncMemory(MemoryBase): for key in promoted_payload_keys: if key in mem.payload: memory_item_dict[key] = mem.payload[key] - - additional_metadata = { - k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - + formatted_memories.append(memory_item_dict) - + return formatted_memories async def search( @@ -1442,16 +1440,13 @@ class AsyncMemory(MemoryBase): and potentially "relations" if graph store is enabled. 
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}` """ - + _, effective_filters = _build_filters_and_metadata( - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - input_filters=filters + user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters ) if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ") + raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ") capture_event( "mem0.search", @@ -1460,22 +1455,20 @@ class AsyncMemory(MemoryBase): ) vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit)) - + graph_task = None if self.enable_graph: if hasattr(self.graph.search, "__await__"): # Check if graph search is async graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit)) else: - graph_task = asyncio.create_task( - asyncio.to_thread(self.graph.search, query, effective_filters, limit) - ) - + graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit)) + if graph_task: original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task) else: original_memories = await vector_store_task graph_entities = None - + if self.enable_graph: return {"results": original_memories, "relations": graph_entities} @@ -1504,11 +1497,8 @@ class AsyncMemory(MemoryBase): "actor_id", "role", ] - - core_and_promoted_keys = { - "data", "hash", "created_at", "updated_at", "id", - *promoted_payload_keys - } + + core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys} original_memories = [] for mem in memories: @@ -1518,19 +1508,17 @@ class AsyncMemory(MemoryBase): hash=mem.payload.get("hash"), created_at=mem.payload.get("created_at"), updated_at=mem.payload.get("updated_at"), - score=mem.score, - ).model_dump() + score=mem.score, + ).model_dump() for key in promoted_payload_keys: if key in mem.payload: memory_item_dict[key] = mem.payload[key] - - additional_metadata = { - k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys - } + + additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys} if additional_metadata: memory_item_dict["metadata"] = additional_metadata - + original_memories.append(memory_item_dict) return original_memories @@ -1650,7 +1638,7 @@ class AsyncMemory(MemoryBase): capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"}) return memory_id - async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None): + async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None): """ Create a procedural memory asynchronously @@ -1709,11 +1697,11 @@ class AsyncMemory(MemoryBase): except Exception: logger.error(f"Error getting memory with ID {memory_id} during update.") raise ValueError(f"Error getting memory with ID {memory_id}. 
Please provide a valid 'memory_id'") - + prev_value = existing_memory.payload.get("data") new_metadata = deepcopy(metadata) if metadata is not None else {} - + new_metadata["data"] = data new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["created_at"] = existing_memory.payload.get("created_at") @@ -1725,8 +1713,7 @@ class AsyncMemory(MemoryBase): new_metadata["agent_id"] = existing_memory.payload["agent_id"] if "run_id" in existing_memory.payload: new_metadata["run_id"] = existing_memory.payload["run_id"] - - + if "actor_id" in existing_memory.payload: new_metadata["actor_id"] = existing_memory.payload["actor_id"] if "role" in existing_memory.payload: @@ -1736,7 +1723,7 @@ class AsyncMemory(MemoryBase): embeddings = existing_embeddings[data] else: embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update") - + await asyncio.to_thread( self.vector_store.update, vector_id=memory_id, @@ -1744,7 +1731,7 @@ class AsyncMemory(MemoryBase): payload=new_metadata, ) logger.info(f"Updating memory with ID {memory_id=} with {data=}") - + await asyncio.to_thread( self.db.add_history, memory_id, diff --git a/mem0/memory/memgraph_memory.py b/mem0/memory/memgraph_memory.py index 8071c1c0..5a7cf6ec 100644 --- a/mem0/memory/memgraph_memory.py +++ b/mem0/memory/memgraph_memory.py @@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities try: from langchain_memgraph import Memgraph except ImportError: - raise ImportError( - "langchain_memgraph is not installed. Please install it using pip install langchain-memgraph" - ) + raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph") try: from rank_bm25 import BM25Okapi except ImportError: - raise ImportError( - "rank_bm25 is not installed. Please install it using pip install rank-bm25" - ) + raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25") from mem0.graphs.tools import ( DELETE_MEMORY_STRUCT_TOOL_GRAPH, @@ -74,22 +70,14 @@ class MemoryGraph: filters (dict): A dictionary containing filters to be applied during the addition. """ entity_type_map = self._retrieve_nodes_from_data(data, filters) - to_be_added = self._establish_nodes_relations_from_data( - data, filters, entity_type_map - ) - search_output = self._search_graph_db( - node_list=list(entity_type_map.keys()), filters=filters - ) - to_be_deleted = self._get_delete_entities_from_search_output( - search_output, data, filters - ) + to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map) + search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters) + to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters) # TODO: Batch queries with APOC plugin # TODO: Add more filter support deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"]) - added_entities = self._add_entities( - to_be_added, filters["user_id"], entity_type_map - ) + added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map) return {"deleted_entities": deleted_entities, "added_entities": added_entities} @@ -108,16 +96,13 @@ class MemoryGraph: - "entities": List of related graph data based on the query. 
""" entity_type_map = self._retrieve_nodes_from_data(query, filters) - search_output = self._search_graph_db( - node_list=list(entity_type_map.keys()), filters=filters - ) + search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters) if not search_output: return [] search_outputs_sequence = [ - [item["source"], item["relationship"], item["destination"]] - for item in search_output + [item["source"], item["relationship"], item["destination"]] for item in search_output ] bm25 = BM25Okapi(search_outputs_sequence) @@ -126,9 +111,7 @@ class MemoryGraph: search_results = [] for item in reranked_results: - search_results.append( - {"source": item[0], "relationship": item[1], "destination": item[2]} - ) + search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]}) logger.info(f"Returned {len(search_results)} search results") @@ -161,9 +144,7 @@ class MemoryGraph: RETURN n.name AS source, type(r) AS relationship, m.name AS target LIMIT $limit """ - results = self.graph.query( - query, params={"user_id": filters["user_id"], "limit": limit} - ) + results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit}) final_results = [] for result in results: @@ -208,13 +189,8 @@ class MemoryGraph: f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}" ) - entity_type_map = { - k.lower().replace(" ", "_"): v.lower().replace(" ", "_") - for k, v in entity_type_map.items() - } - logger.debug( - f"Entity type map: {entity_type_map}\n search_results={search_results}" - ) + entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()} + logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}") return entity_type_map def _establish_nodes_relations_from_data(self, data, filters, entity_type_map): @@ -223,9 +199,7 @@ class MemoryGraph: messages = [ { "role": "system", - "content": EXTRACT_RELATIONS_PROMPT.replace( - "USER_ID", filters["user_id"] - ).replace( + "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace( "CUSTOM_PROMPT", f"4. 
{self.config.graph_store.custom_prompt}" ), }, @@ -235,9 +209,7 @@ class MemoryGraph: messages = [ { "role": "system", - "content": EXTRACT_RELATIONS_PROMPT.replace( - "USER_ID", filters["user_id"] - ), + "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]), }, { "role": "user", @@ -304,9 +276,7 @@ class MemoryGraph: def _get_delete_entities_from_search_output(self, search_output, data, filters): """Get the entities to be deleted from the search output.""" search_output_string = format_entities(search_output) - system_prompt, user_prompt = get_delete_messages( - search_output_string, data, filters["user_id"] - ) + system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"]) _tools = [DELETE_MEMORY_TOOL_GRAPH] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: @@ -379,12 +349,8 @@ class MemoryGraph: # search for the nodes with the closest embeddings; this is basically # comparison of one embedding to all embeddings in a graph -> vector # search with cosine similarity metric - source_node_search_result = self._search_source_node( - source_embedding, user_id, threshold=0.9 - ) - destination_node_search_result = self._search_destination_node( - dest_embedding, user_id, threshold=0.9 - ) + source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9) + destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9) # TODO: Create a cypher query and common params for all the cases if not destination_node_search_result and source_node_search_result: @@ -424,9 +390,7 @@ class MemoryGraph: """ params = { - "destination_id": destination_node_search_result[0][ - "id(destination_candidate)" - ], + "destination_id": destination_node_search_result[0]["id(destination_candidate)"], "source_name": source, "source_embedding": source_embedding, "user_id": user_id, @@ -445,9 +409,7 @@ class MemoryGraph: """ params = { "source_id": source_node_search_result[0]["id(source_candidate)"], - "destination_id": destination_node_search_result[0][ - "id(destination_candidate)" - ], + "destination_id": destination_node_search_result[0]["id(destination_candidate)"], "user_id": user_id, } else: diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py index 982ee020..7df0e000 100644 --- a/mem0/memory/storage.py +++ b/mem0/memory/storage.py @@ -1,8 +1,8 @@ +import logging import sqlite3 import threading import uuid -import logging -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) @@ -23,9 +23,7 @@ class SQLiteManager: """ with self._lock, self.connection: cur = self.connection.cursor() - cur.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='history'" - ) + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'") if cur.fetchone() is None: return # nothing to migrate @@ -51,13 +49,11 @@ class SQLiteManager: logger.info("Migrating history table to new schema (no convo columns).") cur.execute("ALTER TABLE history RENAME TO history_old") - self._create_history_table() + self._create_history_table() intersecting = list(expected_cols & old_cols) cols_csv = ", ".join(intersecting) - cur.execute( - f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old" - ) + cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old") cur.execute("DROP TABLE history_old") def _create_history_table(self) -> None: diff --git 
a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py index d4ad1840..6d822cec 100644 --- a/mem0/memory/telemetry.py +++ b/mem0/memory/telemetry.py @@ -9,8 +9,8 @@ import mem0 from mem0.memory.setup import get_or_create_user_id MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True") -PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX" -HOST="https://us.i.posthog.com" +PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX" +HOST = "https://us.i.posthog.com" if isinstance(MEM0_TELEMETRY, str): MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes") diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index ce75cb7e..93ebbd8a 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -98,9 +98,8 @@ class VectorStoreFactory: return vector_store_instance(**config) else: raise ValueError(f"Unsupported VectorStore provider: {provider_name}") - + @classmethod def reset(cls, instance): instance.reset() return instance - diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index efcd528a..6acd2e40 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -377,4 +377,3 @@ class AzureAISearch(VectorStoreBase): except Exception as e: logger.error(f"Error resetting index {self.index_name}: {e}") raise - diff --git a/mem0/vector_stores/base.py b/mem0/vector_stores/base.py index 98f3503b..3e22499d 100644 --- a/mem0/vector_stores/base.py +++ b/mem0/vector_stores/base.py @@ -51,7 +51,7 @@ class VectorStoreBase(ABC): def list(self, filters=None, limit=None): """List all memories.""" pass - + @abstractmethod def reset(self): """Reset by delete the collection and recreate it.""" diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index ae4b03f7..1de95ad1 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -221,7 +221,7 @@ class ChromaDB(VectorStoreBase): """ results = self.collection.get(where=filters, limit=limit) return [self._parse_output(results)] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/elasticsearch.py b/mem0/vector_stores/elasticsearch.py index 4d733a45..016f3783 100644 --- a/mem0/vector_stores/elasticsearch.py +++ b/mem0/vector_stores/elasticsearch.py @@ -58,7 +58,12 @@ class ElasticsearchDB(VectorStoreBase): "mappings": { "properties": { "text": {"type": "text"}, - "vector": {"type": "dense_vector", "dims": self.embedding_model_dims, "index": True, "similarity": "cosine"}, + "vector": { + "type": "dense_vector", + "dims": self.embedding_model_dims, + "index": True, + "similarity": "cosine", + }, "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}}, } }, @@ -222,7 +227,7 @@ class ElasticsearchDB(VectorStoreBase): ) return [results] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/faiss.py b/mem0/vector_stores/faiss.py index c042738c..cb2cd225 100644 --- a/mem0/vector_stores/faiss.py +++ b/mem0/vector_stores/faiss.py @@ -465,7 +465,7 @@ class FAISS(VectorStoreBase): break return [results] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/langchain.py b/mem0/vector_stores/langchain.py index 93807446..f3fcf07e 100644 --- 
a/mem0/vector_stores/langchain.py +++ b/mem0/vector_stores/langchain.py @@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase logger = logging.getLogger(__name__) + class OutputData(BaseModel): id: Optional[str] # memory id score: Optional[float] # distance @@ -162,10 +163,7 @@ class Langchain(VectorStoreBase): if filters and "user_id" in filters: where_clause = {"user_id": filters["user_id"]} - result = self.client._collection.get( - where=where_clause, - limit=limit - ) + result = self.client._collection.get(where=where_clause, limit=limit) # Convert the result to the expected format if result and isinstance(result, dict): diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py index 775006ff..656234e8 100644 --- a/mem0/vector_stores/milvus.py +++ b/mem0/vector_stores/milvus.py @@ -237,7 +237,7 @@ class MilvusDB(VectorStoreBase): obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata")) memories.append(obj) return [memories] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py index b0e986f0..4f258f32 100644 --- a/mem0/vector_stores/opensearch.py +++ b/mem0/vector_stores/opensearch.py @@ -1,6 +1,6 @@ import logging -from typing import Any, Dict, List, Optional import time +from typing import Any, Dict, List, Optional try: from opensearchpy import OpenSearch, RequestsHttpConnection @@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase): use_ssl=config.use_ssl, verify_certs=config.verify_certs, connection_class=RequestsHttpConnection, - pool_maxsize=20 + pool_maxsize=20, ) self.collection_name = config.collection_name @@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase): def create_col(self, name: str, vector_size: int) -> None: """Create a new collection (index in OpenSearch).""" index_settings = { - "settings": { - "index.knn": True - }, + "settings": {"index.knn": True}, "mappings": { "properties": { "vector_field": { @@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase): "payload": {"type": "object"}, "id": {"type": "keyword"}, } - } + }, } if not self.client.indices.exists(index=name): @@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase): except Exception: retry_count += 1 if retry_count == max_retries: - raise TimeoutError( - f"Index {name} creation timed out after {max_retries} seconds" - ) + raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds") time.sleep(0.5) def insert( @@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase): } # Start building the full query - query_body = { - "size": limit * 2, - "query": None - } + query_body = {"size": limit * 2, "query": None} # Prepare filter conditions if applicable filter_clauses = [] @@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase): for key in ["user_id", "run_id", "agent_id"]: value = filters.get(key) if value: - filter_clauses.append({ - "term": {f"payload.{key}.keyword": value} - }) + filter_clauses.append({"term": {f"payload.{key}.keyword": value}}) # Combine knn with filters if needed if filter_clauses: - query_body["query"] = { - "bool": { - "must": knn_query, - "filter": filter_clauses - } - } + query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}} else: query_body["query"] = knn_query @@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase): hits = response["hits"]["hits"] results = [ - OutputData( - id=hit["_source"].get("id"), - score=hit["_score"], - 
payload=hit["_source"].get("payload", {}) - ) + OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {})) for hit in hits ] return results @@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase): def delete(self, vector_id: str) -> None: """Delete a vector by custom ID.""" # First, find the document by custom ID - search_query = { - "query": { - "term": { - "id": vector_id - } - } - } + search_query = {"query": {"term": {"id": vector_id}}} response = self.client.search(index=self.collection_name, body=search_query) hits = response.get("hits", {}).get("hits", []) @@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase): # Delete using the actual document ID self.client.delete(index=self.collection_name, id=opensearch_id) - def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None: """Update a vector and its payload using the custom 'id' field.""" # First, find the document by custom ID - search_query = { - "query": { - "term": { - "id": vector_id - } - } - } + search_query = {"query": {"term": {"id": vector_id}}} response = self.client.search(index=self.collection_name, body=search_query) hits = response.get("hits", {}).get("hits", []) @@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase): except Exception: pass - def get(self, vector_id: str) -> Optional[OutputData]: """Retrieve a vector by ID.""" try: @@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase): self.create_col(self.collection_name, self.embedding_model_dims) return None - search_query = { - "query": { - "term": { - "id": vector_id - } - } - } + search_query = {"query": {"term": {"id": vector_id}}} response = self.client.search(index=self.collection_name, body=search_query) hits = response["hits"]["hits"] @@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase): if not hits: return None - return OutputData( - id=hits[0]["_source"].get("id"), - score=1.0, - payload=hits[0]["_source"].get("payload", {}) - ) + return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {})) except Exception as e: logger.error(f"Error retrieving vector {vector_id}: {str(e)}") return None @@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase): return self.client.indices.get(index=name) def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]: - try: """List all memories with optional filters.""" - query: Dict = { - "query": { - "match_all": {} - } - } + query: Dict = {"query": {"match_all": {}}} filter_clauses = [] if filters: for key in ["user_id", "run_id", "agent_id"]: value = filters.get(key) if value: - filter_clauses.append({ - "term": {f"payload.{key}.keyword": value} - }) + filter_clauses.append({"term": {f"payload.{key}.keyword": value}}) if filter_clauses: - query["query"] = { - "bool": { - "filter": filter_clauses - } - } + query["query"] = {"bool": {"filter": filter_clauses}} if limit: query["size"] = limit @@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase): response = self.client.search(index=self.collection_name, body=query) hits = response["hits"]["hits"] - return [[ - OutputData( - id=hit["_source"].get("id"), - score=1.0, - payload=hit["_source"].get("payload", {}) - ) - for hit in hits - ]] + return [ + [ + OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {})) + for hit in hits + ] + ] except Exception: return [] - def reset(self): """Reset the index by deleting and recreating it.""" 
logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index cc33077f..3b5d4157 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -286,7 +286,7 @@ class PGVector(VectorStoreBase): self.cur.close() if hasattr(self, "conn"): self.conn.close() - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index 3703878c..33c3b342 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -232,7 +232,7 @@ class Qdrant(VectorStoreBase): with_vectors=False, ) return result - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py index 293d69ec..d2819975 100644 --- a/mem0/vector_stores/redis.py +++ b/mem0/vector_stores/redis.py @@ -88,7 +88,7 @@ class RedisDB(VectorStoreBase): The created index object. """ # Use provided parameters or fall back to instance attributes - collection_name = name or self.schema['index']['name'] + collection_name = name or self.schema["index"]["name"] embedding_dims = vector_size or self.embedding_model_dims distance_metric = distance or "cosine" @@ -237,17 +237,16 @@ class RedisDB(VectorStoreBase): """ Reset the index by deleting and recreating it. """ - collection_name = self.schema['index']['name'] + collection_name = self.schema["index"]["name"] logger.warning(f"Resetting index {collection_name}...") self.delete_col() - + self.index = SearchIndex.from_dict(self.schema) self.index.set_client(self.client) self.index.create(overwrite=True) - - #or use - #self.create_col(collection_name, self.embedding_model_dims) + # or use + # self.create_col(collection_name, self.embedding_model_dims) # Recreate the index with the same parameters self.create_col(collection_name, self.embedding_model_dims) diff --git a/mem0/vector_stores/supabase.py b/mem0/vector_stores/supabase.py index 9d0053d1..e55a979c 100644 --- a/mem0/vector_stores/supabase.py +++ b/mem0/vector_stores/supabase.py @@ -229,7 +229,7 @@ class Supabase(VectorStoreBase): records = self.collection.fetch(ids=ids) return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/mem0/vector_stores/upstash_vector.py b/mem0/vector_stores/upstash_vector.py index 6d9b6a06..82dc0f44 100644 --- a/mem0/vector_stores/upstash_vector.py +++ b/mem0/vector_stores/upstash_vector.py @@ -285,10 +285,9 @@ class UpstashVector(VectorStoreBase): - Per-namespace vector and pending vector counts """ return self.client.info() - + def reset(self): """ Reset the Upstash Vector index. 
""" self.delete_col() - diff --git a/mem0/vector_stores/weaviate.py b/mem0/vector_stores/weaviate.py index b759780b..245fd227 100644 --- a/mem0/vector_stores/weaviate.py +++ b/mem0/vector_stores/weaviate.py @@ -308,7 +308,7 @@ class Weaviate(VectorStoreBase): payload["id"] = str(obj.uuid).split("'")[0] results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload)) return [results] - + def reset(self): """Reset the index by deleting and recreating it.""" logger.warning(f"Resetting index {self.collection_name}...") diff --git a/server/main.py b/server/main.py index e12d9b31..150ae590 100644 --- a/server/main.py +++ b/server/main.py @@ -44,31 +44,14 @@ DEFAULT_CONFIG = { "user": POSTGRES_USER, "password": POSTGRES_PASSWORD, "collection_name": POSTGRES_COLLECTION_NAME, - } + }, }, "graph_store": { "provider": "neo4j", - "config": { - "url": NEO4J_URI, - "username": NEO4J_USERNAME, - "password": NEO4J_PASSWORD - } - }, - "llm": { - "provider": "openai", - "config": { - "api_key": OPENAI_API_KEY, - "temperature": 0.2, - "model": "gpt-4o" - } - }, - "embedder": { - "provider": "openai", - "config": { - "api_key": OPENAI_API_KEY, - "model": "text-embedding-3-small" - } + "config": {"url": NEO4J_URI, "username": NEO4J_USERNAME, "password": NEO4J_PASSWORD}, }, + "llm": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "temperature": 0.2, "model": "gpt-4o"}}, + "embedder": {"provider": "openai", "config": {"api_key": OPENAI_API_KEY, "model": "text-embedding-3-small"}}, "history_db_path": HISTORY_DB_PATH, } @@ -115,9 +98,7 @@ def set_config(config: Dict[str, Any]): def add_memory(memory_create: MemoryCreate): """Store new memories.""" if not any([memory_create.user_id, memory_create.agent_id, memory_create.run_id]): - raise HTTPException( - status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required." 
- ) + raise HTTPException(status_code=400, detail="At least one identifier (user_id, agent_id, run_id) is required.") params = {k: v for k, v in memory_create.model_dump().items() if v is not None and k != "messages"} try: @@ -138,7 +119,9 @@ def get_all_memories( if not any([user_id, run_id, agent_id]): raise HTTPException(status_code=400, detail="At least one identifier is required.") try: - params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None} + params = { + k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None + } return MEMORY_INSTANCE.get_all(**params) except Exception as e: logging.exception("Error in get_all_memories:") @@ -207,7 +190,9 @@ def delete_all_memories( if not any([user_id, run_id, agent_id]): raise HTTPException(status_code=400, detail="At least one identifier is required.") try: - params = {k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None} + params = { + k: v for k, v in {"user_id": user_id, "run_id": run_id, "agent_id": agent_id}.items() if v is not None + } MEMORY_INSTANCE.delete_all(**params) return {"message": "All relevant memories deleted"} except Exception as e: @@ -229,4 +214,4 @@ def reset_memory(): @app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False) def home(): """Redirect to the OpenAPI documentation.""" - return RedirectResponse(url='/docs') + return RedirectResponse(url="/docs") diff --git a/tests/configs/test_prompts.py b/tests/configs/test_prompts.py index 7fc50fec..e978f8c9 100644 --- a/tests/configs/test_prompts.py +++ b/tests/configs/test_prompts.py @@ -5,13 +5,15 @@ def test_get_update_memory_messages(): retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}] response_content = ["new fact"] custom_update_memory_prompt = "custom prompt determining memory update" - + ## When custom update memory prompt is provided ## - result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt) + result = prompts.get_update_memory_messages( + retrieved_old_memory_dict, response_content, custom_update_memory_prompt + ) assert result.startswith(custom_update_memory_prompt) - + ## When custom update memory prompt is not provided ## result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None) - assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT) \ No newline at end of file + assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT) diff --git a/tests/embeddings/test_lm_studio_embeddings.py b/tests/embeddings/test_lm_studio_embeddings.py index 55537bcd..e37476c8 100644 --- a/tests/embeddings/test_lm_studio_embeddings.py +++ b/tests/embeddings/test_lm_studio_embeddings.py @@ -10,9 +10,7 @@ from mem0.embeddings.lmstudio import LMStudioEmbedding def mock_lm_studio_client(): with patch("mem0.embeddings.lmstudio.OpenAI") as mock_openai: mock_client = Mock() - mock_client.embeddings.create.return_value = Mock( - data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])] - ) + mock_client.embeddings.create.return_value = Mock(data=[Mock(embedding=[0.1, 0.2, 0.3, 0.4, 0.5])]) mock_openai.return_value = mock_client yield mock_client diff --git a/tests/embeddings/test_openai_embeddings.py b/tests/embeddings/test_openai_embeddings.py index ae365985..08dda117 100644 --- a/tests/embeddings/test_openai_embeddings.py +++ b/tests/embeddings/test_openai_embeddings.py @@ -23,7 +23,9 @@ def 
test_embed_default_model(mock_openai_client): result = embedder.embed("Hello world") - mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536) + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Hello world"], model="text-embedding-3-small", dimensions=1536 + ) assert result == [0.1, 0.2, 0.3] @@ -37,7 +39,7 @@ def test_embed_custom_model(mock_openai_client): result = embedder.embed("Test embedding") mock_openai_client.embeddings.create.assert_called_once_with( - input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024 + input=["Test embedding"], model="text-embedding-2-medium", dimensions=1024 ) assert result == [0.4, 0.5, 0.6] @@ -51,7 +53,9 @@ def test_embed_removes_newlines(mock_openai_client): result = embedder.embed("Hello\nworld") - mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536) + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Hello world"], model="text-embedding-3-small", dimensions=1536 + ) assert result == [0.7, 0.8, 0.9] @@ -65,7 +69,7 @@ def test_embed_without_api_key_env_var(mock_openai_client): result = embedder.embed("Testing API key") mock_openai_client.embeddings.create.assert_called_once_with( - input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536 + input=["Testing API key"], model="text-embedding-3-small", dimensions=1536 ) assert result == [1.0, 1.1, 1.2] @@ -81,6 +85,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch): result = embedder.embed("Environment key test") mock_openai_client.embeddings.create.assert_called_once_with( - input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536 + input=["Environment key test"], model="text-embedding-3-small", dimensions=1536 ) assert result == [1.3, 1.4, 1.5] diff --git a/tests/embeddings/test_vertexai_embeddings.py b/tests/embeddings/test_vertexai_embeddings.py index 3353d673..9f541526 100644 --- a/tests/embeddings/test_vertexai_embeddings.py +++ b/tests/embeddings/test_vertexai_embeddings.py @@ -24,11 +24,20 @@ def mock_config(): with patch("mem0.configs.embeddings.base.BaseEmbedderConfig") as mock_config: mock_config.return_value.vertex_credentials_json = "/path/to/credentials.json" yield mock_config - + @pytest.fixture def mock_embedding_types(): - return ["SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY", "QUESTION_ANSWERING", "FACT_VERIFICATION", "CODE_RETRIEVAL_QUERY"] + return [ + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING", + "RETRIEVAL_DOCUMENT", + "RETRIEVAL_QUERY", + "QUESTION_ANSWERING", + "FACT_VERIFICATION", + "CODE_RETRIEVAL_QUERY", + ] @pytest.fixture @@ -79,30 +88,31 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con assert result == [0.4, 0.5, 0.6] -@patch("mem0.embeddings.vertexai.TextEmbeddingModel") -def test_embed_with_memory_action(mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input): +@patch("mem0.embeddings.vertexai.TextEmbeddingModel") +def test_embed_with_memory_action( + mock_text_embedding_model, mock_os_environ, mock_config, mock_embedding_types, mock_text_embedding_input +): mock_config.return_value.model = "text-embedding-004" mock_config.return_value.embedding_dims = 256 - + for embedding_type in mock_embedding_types: - 
mock_config.return_value.memory_add_embedding_type = embedding_type mock_config.return_value.memory_update_embedding_type = embedding_type mock_config.return_value.memory_search_embedding_type = embedding_type config = mock_config() embedder = VertexAIEmbedding(config) - + mock_text_embedding_model.from_pretrained.assert_called_with("text-embedding-004") for memory_action in ["add", "update", "search"]: embedder.embed("Hello world", memory_action=memory_action) - + mock_text_embedding_input.assert_called_with(text="Hello world", task_type=embedding_type) mock_text_embedding_model.from_pretrained.return_value.get_embeddings.assert_called_with( texts=[mock_text_embedding_input("Hello world", embedding_type)], output_dimensionality=256 ) - + @patch("mem0.embeddings.vertexai.os") def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_config): @@ -137,15 +147,15 @@ def test_embed_with_different_dimensions(mock_text_embedding_model, mock_os_envi result = embedder.embed("Large embedding test") assert result == [0.1] * 1024 - -@patch("mem0.embeddings.vertexai.TextEmbeddingModel") + +@patch("mem0.embeddings.vertexai.TextEmbeddingModel") def test_invalid_memory_action(mock_text_embedding_model, mock_config): mock_config.return_value.model = "text-embedding-004" mock_config.return_value.embedding_dims = 256 - + config = mock_config() embedder = VertexAIEmbedding(config) - + with pytest.raises(ValueError): - embedder.embed("Hello world", memory_action="invalid_action") \ No newline at end of file + embedder.embed("Hello world", memory_action="invalid_action") diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py index f377f537..7ef86e94 100644 --- a/tests/llms/test_azure_openai.py +++ b/tests/llms/test_azure_openai.py @@ -127,4 +127,4 @@ def test_generate_with_http_proxies(default_headers): api_version=None, default_headers=default_headers, ) - mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") \ No newline at end of file + mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") diff --git a/tests/llms/test_deepseek.py b/tests/llms/test_deepseek.py index b15a8972..b7d3f5e9 100644 --- a/tests/llms/test_deepseek.py +++ b/tests/llms/test_deepseek.py @@ -31,12 +31,12 @@ def test_deepseek_llm_base_url(): # case3: with config.deepseek_base_url config_base_url = "https://api.config.com/v1/" config = BaseLlmConfig( - model="deepseek-chat", - temperature=0.7, - max_tokens=100, - top_p=1.0, - api_key="api_key", - deepseek_base_url=config_base_url + model="deepseek-chat", + temperature=0.7, + max_tokens=100, + top_p=1.0, + api_key="api_key", + deepseek_base_url=config_base_url, ) llm = DeepSeekLLM(config) assert str(llm.client.base_url) == config_base_url @@ -99,16 +99,16 @@ def test_generate_response_with_tools(mock_deepseek_client): response = llm.generate_response(messages, tools=tools) mock_deepseek_client.chat.completions.create.assert_called_once_with( - model="deepseek-chat", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0, - tools=tools, - tool_choice="auto" + model="deepseek-chat", + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0, + tools=tools, + tool_choice="auto", ) assert response["content"] == "I've added the memory for you." 
assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} \ No newline at end of file + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py index 59ff407c..11764a6e 100644 --- a/tests/llms/test_langchain.py +++ b/tests/llms/test_langchain.py @@ -10,6 +10,7 @@ try: from langchain.chat_models.base import BaseChatModel except ImportError: from unittest.mock import MagicMock + BaseChatModel = MagicMock @@ -24,16 +25,11 @@ def mock_langchain_model(): def test_langchain_initialization(mock_langchain_model): """Test that LangchainLLM initializes correctly with a valid model.""" # Create a config with the model instance directly - config = BaseLlmConfig( - model=mock_langchain_model, - temperature=0.7, - max_tokens=100, - api_key="test-api-key" - ) - + config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key") + # Initialize the LangchainLLM llm = LangchainLLM(config) - + # Verify the model was correctly assigned assert llm.langchain_model == mock_langchain_model @@ -41,35 +37,30 @@ def test_langchain_initialization(mock_langchain_model): def test_generate_response(mock_langchain_model): """Test that generate_response correctly processes messages and returns a response.""" # Create a config with the model instance - config = BaseLlmConfig( - model=mock_langchain_model, - temperature=0.7, - max_tokens=100, - api_key="test-api-key" - ) - + config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key") + # Initialize the LangchainLLM llm = LangchainLLM(config) - + # Create test messages messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing well! How can I help you?"}, - {"role": "user", "content": "Tell me a joke."} + {"role": "user", "content": "Tell me a joke."}, ] - + # Get response response = llm.generate_response(messages) - + # Verify the correct message format was passed to the model expected_langchain_messages = [ ("system", "You are a helpful assistant."), ("human", "Hello, how are you?"), ("ai", "I'm doing well! 
How can I help you?"), - ("human", "Tell me a joke.") + ("human", "Tell me a joke."), ] - + mock_langchain_model.invoke.assert_called_once() # Extract the first argument of the first call actual_messages = mock_langchain_model.invoke.call_args[0][0] @@ -79,25 +70,15 @@ def test_generate_response(mock_langchain_model): def test_invalid_model(): """Test that LangchainLLM raises an error with an invalid model.""" - config = BaseLlmConfig( - model="not-a-valid-model-instance", - temperature=0.7, - max_tokens=100, - api_key="test-api-key" - ) - + config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key") + with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"): LangchainLLM(config) def test_missing_model(): """Test that LangchainLLM raises an error when model is None.""" - config = BaseLlmConfig( - model=None, - temperature=0.7, - max_tokens=100, - api_key="test-api-key" - ) - + config = BaseLlmConfig(model=None, temperature=0.7, max_tokens=100, api_key="test-api-key") + with pytest.raises(ValueError, match="`model` parameter is required"): LangchainLLM(config) diff --git a/tests/llms/test_lm_studio.py b/tests/llms/test_lm_studio.py index 8d0e4871..bed1b98b 100644 --- a/tests/llms/test_lm_studio.py +++ b/tests/llms/test_lm_studio.py @@ -11,9 +11,7 @@ def mock_lm_studio_client(): with patch("mem0.llms.lmstudio.OpenAI") as mock_openai: # Corrected path mock_client = Mock() mock_client.chat.completions.create.return_value = Mock( - choices=[ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + choices=[Mock(message=Mock(content="I'm doing well, thank you for asking!"))] ) mock_openai.return_value = mock_client yield mock_client diff --git a/tests/memory/test_main.py b/tests/memory/test_main.py index 18966e44..64a8f837 100644 --- a/tests/memory/test_main.py +++ b/tests/memory/test_main.py @@ -10,18 +10,19 @@ def _setup_mocks(mocker): """Helper to setup common mocks for both sync and async fixtures""" mock_embedder = mocker.MagicMock() mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3] - mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder) - + mocker.patch("mem0.utils.factory.EmbedderFactory.create", mock_embedder) + mock_vector_store = mocker.MagicMock() mock_vector_store.return_value.search.return_value = [] - mocker.patch('mem0.utils.factory.VectorStoreFactory.create', - side_effect=[mock_vector_store.return_value, mocker.MagicMock()]) - + mocker.patch( + "mem0.utils.factory.VectorStoreFactory.create", side_effect=[mock_vector_store.return_value, mocker.MagicMock()] + ) + mock_llm = mocker.MagicMock() - mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm) - - mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock()) - + mocker.patch("mem0.utils.factory.LlmFactory.create", mock_llm) + + mocker.patch("mem0.memory.storage.SQLiteManager", mocker.MagicMock()) + return mock_llm, mock_vector_store @@ -30,29 +31,26 @@ class TestAddToVectorStoreErrors: def mock_memory(self, mocker): """Fixture that returns a Memory instance with mocker-based mocks""" mock_llm, _ = _setup_mocks(mocker) - + memory = Memory() memory.config = mocker.MagicMock() memory.config.custom_fact_extraction_prompt = None memory.config.custom_update_memory_prompt = None memory.api_version = "v1.1" - + return memory def test_empty_llm_response_fact_extraction(self, mock_memory, caplog): """Test empty response from LLM during fact extraction""" # Setup 
mock_memory.llm.generate_response.return_value = "" - + # Execute with caplog.at_level(logging.ERROR): result = mock_memory._add_to_vector_store( - messages=[{"role": "user", "content": "test"}], - metadata={}, - filters={}, - infer=True + messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True ) - + # Verify assert mock_memory.llm.generate_response.call_count == 2 assert result == [] # Should return empty list when no memories processed @@ -62,20 +60,14 @@ class TestAddToVectorStoreErrors: """Test empty response from LLM during memory actions""" # Setup # First call returns valid JSON, second call returns empty string - mock_memory.llm.generate_response.side_effect = [ - '{"facts": ["test fact"]}', - "" - ] - + mock_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""] + # Execute with caplog.at_level(logging.ERROR): result = mock_memory._add_to_vector_store( - messages=[{"role": "user", "content": "test"}], - metadata={}, - filters={}, - infer=True + messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True ) - + # Verify assert mock_memory.llm.generate_response.call_count == 2 assert result == [] # Should return empty list when no memories processed @@ -88,48 +80,39 @@ class TestAsyncAddToVectorStoreErrors: def mock_async_memory(self, mocker): """Fixture for AsyncMemory with mocker-based mocks""" mock_llm, _ = _setup_mocks(mocker) - + memory = AsyncMemory() memory.config = mocker.MagicMock() memory.config.custom_fact_extraction_prompt = None memory.config.custom_update_memory_prompt = None memory.api_version = "v1.1" - + return memory @pytest.mark.asyncio async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker): """Test empty response in AsyncMemory._add_to_vector_store""" - mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock()) + mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock()) mock_async_memory.llm.generate_response.return_value = "" - + with caplog.at_level(logging.ERROR): result = await mock_async_memory._add_to_vector_store( - messages=[{"role": "user", "content": "test"}], - metadata={}, - filters={}, - infer=True + messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True ) - + assert result == [] assert "Error in new_retrieved_facts" in caplog.text @pytest.mark.asyncio async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker): """Test empty response in AsyncMemory._add_to_vector_store""" - mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock()) - mock_async_memory.llm.generate_response.side_effect = [ - '{"facts": ["test fact"]}', - "" - ] - + mocker.patch("mem0.utils.factory.EmbedderFactory.create", return_value=MagicMock()) + mock_async_memory.llm.generate_response.side_effect = ['{"facts": ["test fact"]}', ""] + with caplog.at_level(logging.ERROR): result = await mock_async_memory._add_to_vector_store( - messages=[{"role": "user", "content": "test"}], - metadata={}, - filters={}, - infer=True + messages=[{"role": "user", "content": "test"}], metadata={}, filters={}, infer=True ) - + assert result == [] assert "Invalid JSON response" in caplog.text diff --git a/tests/test_main.py b/tests/test_main.py index 60135a7f..41afa05b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -17,11 +17,13 @@ def mock_openai(): @pytest.fixture def memory_instance(): - with patch("mem0.utils.factory.EmbedderFactory") as 
mock_embedder, patch( - "mem0.utils.factory.VectorStoreFactory" - ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch( - "mem0.memory.telemetry.capture_event" - ), patch("mem0.memory.graph_memory.MemoryGraph"): + with ( + patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, + patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store, + patch("mem0.utils.factory.LlmFactory") as mock_llm, + patch("mem0.memory.telemetry.capture_event"), + patch("mem0.memory.graph_memory.MemoryGraph"), + ): mock_embedder.create.return_value = Mock() mock_vector_store.create.return_value = Mock() mock_llm.create.return_value = Mock() @@ -30,13 +32,16 @@ def memory_instance(): config.graph_store.config = {"some_config": "value"} return Memory(config) + @pytest.fixture def memory_custom_instance(): - with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch( - "mem0.utils.factory.VectorStoreFactory" - ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch( - "mem0.memory.telemetry.capture_event" - ), patch("mem0.memory.graph_memory.MemoryGraph"): + with ( + patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, + patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store, + patch("mem0.utils.factory.LlmFactory") as mock_llm, + patch("mem0.memory.telemetry.capture_event"), + patch("mem0.memory.graph_memory.MemoryGraph"), + ): mock_embedder.create.return_value = Mock() mock_vector_store.create.return_value = Mock() mock_llm.create.return_value = Mock() @@ -44,7 +49,7 @@ def memory_custom_instance(): config = MemoryConfig( version="v1.1", custom_fact_extraction_prompt="custom prompt extracting memory", - custom_update_memory_prompt="custom prompt determining memory update" + custom_update_memory_prompt="custom prompt determining memory update", ) config.graph_store.config = {"some_config": "value"} return Memory(config) @@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph): assert result["message"] == "Memories deleted successfully!" 
- @pytest.mark.parametrize( "version, enable_graph, expected_result", [ @@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result): memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100) else: memory_instance.graph.get_all.assert_not_called() - + def test_custom_prompts(memory_custom_instance): messages = [{"role": "user", "content": "Test message"}] memory_custom_instance.llm.generate_response = Mock() - + with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages: - with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages: + with patch( + "mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt" + ) as mock_get_update_memory_messages: memory_custom_instance.add(messages=messages, user_id="test_user") - + ## custom prompt ## mock_parse_messages.assert_called_once_with(messages) - + memory_custom_instance.llm.generate_response.assert_any_call( messages=[ {"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt}, @@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance): ], response_format={"type": "json_object"}, ) - + ## custom update memory prompt ## - mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt) - + mock_get_update_memory_messages.assert_called_once_with( + [], [], memory_custom_instance.config.custom_update_memory_prompt + ) + memory_custom_instance.llm.generate_response.assert_any_call( messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}], response_format={"type": "json_object"}, - ) \ No newline at end of file + ) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index f7aaa1eb..ba318842 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -97,4 +97,4 @@ def test_completions_create_with_system_message(mock_memory_client, mock_litellm call_args = mock_litellm.completion.call_args[1] assert call_args["messages"][0]["role"] == "system" - assert call_args["messages"][0]["content"] == "You are a helpful assistant." \ No newline at end of file + assert call_args["messages"][0]["content"] == "You are a helpful assistant." diff --git a/tests/vector_stores/test_azure_ai_search.py b/tests/vector_stores/test_azure_ai_search.py index 8fee7b68..764ea322 100644 --- a/tests/vector_stores/test_azure_ai_search.py +++ b/tests/vector_stores/test_azure_ai_search.py @@ -13,13 +13,15 @@ from mem0.vector_stores.azure_ai_search import AzureAISearch # Fixture to patch SearchClient and SearchIndexClient and create an instance of AzureAISearch. @pytest.fixture def mock_clients(): - with patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, \ - patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, \ - patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential: + with ( + patch("mem0.vector_stores.azure_ai_search.SearchClient") as MockSearchClient, + patch("mem0.vector_stores.azure_ai_search.SearchIndexClient") as MockIndexClient, + patch("mem0.vector_stores.azure_ai_search.AzureKeyCredential") as MockAzureKeyCredential, + ): # Create mocked instances for search and index clients. 
mock_search_client = MockSearchClient.return_value mock_index_client = MockIndexClient.return_value - + # Mock the client._client._config.user_agent_policy.add_user_agent mock_search_client._client = MagicMock() mock_search_client._client._config.user_agent_policy.add_user_agent = Mock() @@ -62,7 +64,7 @@ def azure_ai_search_instance(mock_clients): api_key="test-api-key", embedding_model_dims=3, compression_type="binary", # testing binary quantization option - use_float16=True + use_float16=True, ) # Return instance and clients for verification. return instance, mock_search_client, mock_index_client @@ -70,21 +72,18 @@ def azure_ai_search_instance(mock_clients): # --- Tests for AzureAISearchConfig --- + def test_config_validation_valid(): """Test valid configurations are accepted.""" # Test minimal configuration - config = AzureAISearchConfig( - service_name="test-service", - api_key="test-api-key", - embedding_model_dims=768 - ) + config = AzureAISearchConfig(service_name="test-service", api_key="test-api-key", embedding_model_dims=768) assert config.collection_name == "mem0" # Default value assert config.service_name == "test-service" assert config.api_key == "test-api-key" assert config.embedding_model_dims == 768 assert config.compression_type is None assert config.use_float16 is False - + # Test with all optional parameters config = AzureAISearchConfig( collection_name="custom-index", @@ -92,7 +91,7 @@ def test_config_validation_valid(): api_key="test-api-key", embedding_model_dims=1536, compression_type="scalar", - use_float16=True + use_float16=True, ) assert config.collection_name == "custom-index" assert config.compression_type == "scalar" @@ -106,7 +105,7 @@ def test_config_validation_invalid_compression_type(): service_name="test-service", api_key="test-api-key", embedding_model_dims=768, - compression_type="invalid-type" # Not a valid option + compression_type="invalid-type", # Not a valid option ) assert "Invalid compression_type" in str(exc_info.value) @@ -118,7 +117,7 @@ def test_config_validation_deprecated_use_compression(): service_name="test-service", api_key="test-api-key", embedding_model_dims=768, - use_compression=True # Deprecated parameter + use_compression=True, # Deprecated parameter ) # Fix: Use a partial string match instead of exact match assert "use_compression" in str(exc_info.value) @@ -132,7 +131,7 @@ def test_config_validation_extra_fields(): service_name="test-service", api_key="test-api-key", embedding_model_dims=768, - unknown_parameter="value" # Extra field + unknown_parameter="value", # Extra field ) assert "Extra fields not allowed" in str(exc_info.value) assert "unknown_parameter" in str(exc_info.value) @@ -140,30 +139,28 @@ def test_config_validation_extra_fields(): # --- Tests for AzureAISearch initialization --- + def test_initialization(mock_clients): """Test AzureAISearch initialization with different parameters.""" mock_search_client, mock_index_client, mock_azure_key_credential = mock_clients - + # Test with minimal parameters instance = AzureAISearch( - service_name="test-service", - collection_name="test-index", - api_key="test-api-key", - embedding_model_dims=768 + service_name="test-service", collection_name="test-index", api_key="test-api-key", embedding_model_dims=768 ) - + # Verify initialization parameters assert instance.index_name == "test-index" assert instance.collection_name == "test-index" assert instance.embedding_model_dims == 768 assert instance.compression_type == "none" # Default when None is passed assert 
instance.use_float16 is False - + # Verify client creation mock_azure_key_credential.assert_called_with("test-api-key") assert "mem0" in mock_search_client._client._config.user_agent_policy.add_user_agent.call_args[0] assert "mem0" in mock_index_client._client._config.user_agent_policy.add_user_agent.call_args[0] - + # Verify index creation was called mock_index_client.create_or_update_index.assert_called_once() @@ -171,75 +168,75 @@ def test_initialization(mock_clients): def test_initialization_with_compression_types(mock_clients): """Test initialization with different compression types.""" mock_search_client, mock_index_client, _ = mock_clients - + # Test with scalar compression instance = AzureAISearch( service_name="test-service", collection_name="scalar-index", api_key="test-api-key", embedding_model_dims=768, - compression_type="scalar" + compression_type="scalar", ) assert instance.compression_type == "scalar" - + # Capture the index creation call args, _ = mock_index_client.create_or_update_index.call_args_list[-1] index = args[0] # Verify scalar compression was configured - assert hasattr(index.vector_search, 'compressions') + assert hasattr(index.vector_search, "compressions") assert len(index.vector_search.compressions) > 0 assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) - + # Test with binary compression instance = AzureAISearch( service_name="test-service", collection_name="binary-index", api_key="test-api-key", embedding_model_dims=768, - compression_type="binary" + compression_type="binary", ) assert instance.compression_type == "binary" - + # Capture the index creation call args, _ = mock_index_client.create_or_update_index.call_args_list[-1] index = args[0] # Verify binary compression was configured - assert hasattr(index.vector_search, 'compressions') + assert hasattr(index.vector_search, "compressions") assert len(index.vector_search.compressions) > 0 assert "BinaryQuantizationCompression" in str(type(index.vector_search.compressions[0])) - + # Test with no compression instance = AzureAISearch( service_name="test-service", collection_name="no-compression-index", api_key="test-api-key", embedding_model_dims=768, - compression_type=None + compression_type=None, ) assert instance.compression_type == "none" - + # Capture the index creation call args, _ = mock_index_client.create_or_update_index.call_args_list[-1] index = args[0] # Verify no compression was configured - assert hasattr(index.vector_search, 'compressions') + assert hasattr(index.vector_search, "compressions") assert len(index.vector_search.compressions) == 0 def test_initialization_with_float_precision(mock_clients): """Test initialization with different float precision settings.""" mock_search_client, mock_index_client, _ = mock_clients - + # Test with half precision (float16) instance = AzureAISearch( service_name="test-service", collection_name="float16-index", api_key="test-api-key", embedding_model_dims=768, - use_float16=True + use_float16=True, ) assert instance.use_float16 is True - + # Capture the index creation call args, _ = mock_index_client.create_or_update_index.call_args_list[-1] index = args[0] @@ -247,17 +244,17 @@ def test_initialization_with_float_precision(mock_clients): vector_field = next((f for f in index.fields if f.name == "vector"), None) assert vector_field is not None assert "Edm.Half" in vector_field.type - + # Test with full precision (float32) instance = AzureAISearch( service_name="test-service", collection_name="float32-index", 
api_key="test-api-key", embedding_model_dims=768, - use_float16=False + use_float16=False, ) assert instance.use_float16 is False - + # Capture the index creation call args, _ = mock_index_client.create_or_update_index.call_args_list[-1] index = args[0] @@ -269,21 +266,22 @@ def test_initialization_with_float_precision(mock_clients): # --- Tests for create_col method --- + def test_create_col(azure_ai_search_instance): """Test the create_col method creates an index with the correct configuration.""" instance, _, mock_index_client = azure_ai_search_instance - + # create_col is called during initialization, so we check the call that was already made mock_index_client.create_or_update_index.assert_called_once() - + # Verify the index configuration args, _ = mock_index_client.create_or_update_index.call_args index = args[0] - + # Check basic properties assert index.name == "test-index" assert len(index.fields) == 6 # id, user_id, run_id, agent_id, vector, payload - + # Check that required fields are present field_names = [f.name for f in index.fields] assert "id" in field_names @@ -292,22 +290,22 @@ def test_create_col(azure_ai_search_instance): assert "user_id" in field_names assert "run_id" in field_names assert "agent_id" in field_names - + # Check that id is the key field id_field = next(f for f in index.fields if f.name == "id") assert id_field.key is True - + # Check vector search configuration assert index.vector_search is not None assert len(index.vector_search.profiles) == 1 assert index.vector_search.profiles[0].name == "my-vector-config" assert index.vector_search.profiles[0].algorithm_configuration_name == "my-algorithms-config" - + # Check algorithms assert len(index.vector_search.algorithms) == 1 assert index.vector_search.algorithms[0].name == "my-algorithms-config" assert "HnswAlgorithmConfiguration" in str(type(index.vector_search.algorithms[0])) - + # With binary compression and float16, we should have compression configuration assert len(index.vector_search.compressions) == 1 assert index.vector_search.compressions[0].compression_name == "myCompression" @@ -317,24 +315,24 @@ def test_create_col(azure_ai_search_instance): def test_create_col_scalar_compression(mock_clients): """Test creating a collection with scalar compression.""" mock_search_client, mock_index_client, _ = mock_clients - + AzureAISearch( service_name="test-service", collection_name="scalar-index", api_key="test-api-key", embedding_model_dims=768, - compression_type="scalar" + compression_type="scalar", ) - + # Verify the index configuration args, _ = mock_index_client.create_or_update_index.call_args index = args[0] - + # Check compression configuration assert len(index.vector_search.compressions) == 1 assert index.vector_search.compressions[0].compression_name == "myCompression" assert "ScalarQuantizationCompression" in str(type(index.vector_search.compressions[0])) - + # Check profile references compression assert index.vector_search.profiles[0].compression_name == "myCompression" @@ -342,28 +340,29 @@ def test_create_col_scalar_compression(mock_clients): def test_create_col_no_compression(mock_clients): """Test creating a collection with no compression.""" mock_search_client, mock_index_client, _ = mock_clients - + AzureAISearch( service_name="test-service", collection_name="no-compression-index", api_key="test-api-key", embedding_model_dims=768, - compression_type=None + compression_type=None, ) - + # Verify the index configuration args, _ = mock_index_client.create_or_update_index.call_args index = 
args[0] - + # Check compression configuration - should be empty assert len(index.vector_search.compressions) == 0 - + # Check profile doesn't reference compression assert index.vector_search.profiles[0].compression_name is None # --- Tests for insert method --- + def test_insert_single(azure_ai_search_instance): """Test inserting a single vector.""" instance, mock_search_client, _ = azure_ai_search_instance @@ -372,9 +371,7 @@ def test_insert_single(azure_ai_search_instance): ids = ["doc1"] # Fix: Include status_code: 201 in mock response - mock_search_client.upload_documents.return_value = [ - {"status": True, "id": "doc1", "status_code": 201} - ] + mock_search_client.upload_documents.return_value = [{"status": True, "id": "doc1", "status_code": 201}] instance.insert(vectors, payloads, ids) @@ -396,35 +393,35 @@ def test_insert_single(azure_ai_search_instance): def test_insert_multiple(azure_ai_search_instance): """Test inserting multiple vectors in one call.""" instance, mock_search_client, _ = azure_ai_search_instance - + # Create multiple vectors num_docs = 3 - vectors = [[float(i)/10, float(i+1)/10, float(i+2)/10] for i in range(num_docs)] + vectors = [[float(i) / 10, float(i + 1) / 10, float(i + 2) / 10] for i in range(num_docs)] payloads = [{"user_id": f"user{i}", "content": f"Test content {i}"} for i in range(num_docs)] ids = [f"doc{i}" for i in range(num_docs)] - + # Configure mock to return success for all documents (fix: add status_code 201) mock_search_client.upload_documents.return_value = [ {"status": True, "id": id_val, "status_code": 201} for id_val in ids ] - + # Insert the documents instance.insert(vectors, payloads, ids) - + # Verify upload_documents was called with correct documents mock_search_client.upload_documents.assert_called_once() args, _ = mock_search_client.upload_documents.call_args documents = args[0] - + # Verify all documents were included assert len(documents) == num_docs - + # Check first document assert documents[0]["id"] == "doc0" assert documents[0]["vector"] == [0.0, 0.1, 0.2] assert documents[0]["payload"] == json.dumps(payloads[0]) assert documents[0]["user_id"] == "user0" - + # Check last document assert documents[2]["id"] == "doc2" assert documents[2]["vector"] == [0.2, 0.3, 0.4] @@ -437,9 +434,7 @@ def test_insert_with_error(azure_ai_search_instance): instance, mock_search_client, _ = azure_ai_search_instance # Configure mock to return an error for one document - mock_search_client.upload_documents.return_value = [ - {"status": False, "id": "doc1", "errorMessage": "Azure error"} - ] + mock_search_client.upload_documents.return_value = [{"status": False, "id": "doc1", "errorMessage": "Azure error"}] vectors = [[0.1, 0.2, 0.3]] payloads = [{"user_id": "user1"}] @@ -454,7 +449,7 @@ def test_insert_with_error(azure_ai_search_instance): # Configure mock to return mixed success/failure for multiple documents mock_search_client.upload_documents.return_value = [ {"status": True, "id": "doc1"}, # This should not cause failure - {"status": False, "id": "doc2", "errorMessage": "Azure error"} + {"status": False, "id": "doc2", "errorMessage": "Azure error"}, ] vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] @@ -465,8 +460,9 @@ def test_insert_with_error(azure_ai_search_instance): with pytest.raises(Exception) as exc_info: instance.insert(vectors, payloads, ids) - assert "Insert failed for document doc2" in str(exc_info.value) or \ - "Insert failed for document doc1" in str(exc_info.value) + assert "Insert failed for document doc2" in str(exc_info.value) or 
"Insert failed for document doc1" in str( + exc_info.value + ) def test_insert_with_missing_payload_fields(azure_ai_search_instance): @@ -500,23 +496,24 @@ def test_insert_with_missing_payload_fields(azure_ai_search_instance): def test_insert_with_http_error(azure_ai_search_instance): """Test insert when Azure client throws an HTTP error.""" instance, mock_search_client, _ = azure_ai_search_instance - + # Configure mock to raise an HttpResponseError mock_search_client.upload_documents.side_effect = HttpResponseError("Azure service error") - + vectors = [[0.1, 0.2, 0.3]] payloads = [{"user_id": "user1"}] ids = ["doc1"] - + # Insert should propagate the HTTP error with pytest.raises(HttpResponseError) as exc_info: instance.insert(vectors, payloads, ids) - + assert "Azure service error" in str(exc_info.value) # --- Tests for search method --- + def test_search_basic(azure_ai_search_instance): """Test basic vector search without filters.""" instance, mock_search_client, _ = azure_ai_search_instance @@ -536,9 +533,7 @@ def test_search_basic(azure_ai_search_instance): # Search with a vector query_text = "test query" # Add a query string query_vector = [0.1, 0.2, 0.3] - results = instance.search( - query_text, query_vector, limit=5 - ) # Pass the query string + results = instance.search(query_text, query_vector, limit=5) # Pass the query string # Verify search was called correctly mock_search_client.search.assert_called_once() diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py index 3107cf7c..8fd7b550 100644 --- a/tests/vector_stores/test_elasticsearch.py +++ b/tests/vector_stores/test_elasticsearch.py @@ -7,9 +7,7 @@ import dotenv try: from elasticsearch import Elasticsearch except ImportError: - raise ImportError( - "Elasticsearch requires extra dependencies. Install with `pip install elasticsearch`" - ) from None + raise ImportError("Elasticsearch requires extra dependencies. 
Install with `pip install elasticsearch`") from None from mem0.vector_stores.elasticsearch import ElasticsearchDB, OutputData @@ -19,20 +17,20 @@ class TestElasticsearchDB(unittest.TestCase): def setUpClass(cls): # Load environment variables before any test dotenv.load_dotenv() - + # Save original environment variables cls.original_env = { - 'ES_URL': os.getenv('ES_URL', 'http://localhost:9200'), - 'ES_USERNAME': os.getenv('ES_USERNAME', 'test_user'), - 'ES_PASSWORD': os.getenv('ES_PASSWORD', 'test_password'), - 'ES_CLOUD_ID': os.getenv('ES_CLOUD_ID', 'test_cloud_id') + "ES_URL": os.getenv("ES_URL", "http://localhost:9200"), + "ES_USERNAME": os.getenv("ES_USERNAME", "test_user"), + "ES_PASSWORD": os.getenv("ES_PASSWORD", "test_password"), + "ES_CLOUD_ID": os.getenv("ES_CLOUD_ID", "test_cloud_id"), } - + # Set test environment variables - os.environ['ES_URL'] = 'http://localhost' - os.environ['ES_USERNAME'] = 'test_user' - os.environ['ES_PASSWORD'] = 'test_password' - + os.environ["ES_URL"] = "http://localhost" + os.environ["ES_USERNAME"] = "test_user" + os.environ["ES_PASSWORD"] = "test_password" + def setUp(self): # Create a mock Elasticsearch client with proper attributes self.client_mock = MagicMock(spec=Elasticsearch) @@ -41,25 +39,25 @@ class TestElasticsearchDB(unittest.TestCase): self.client_mock.indices.create = MagicMock() self.client_mock.indices.delete = MagicMock() self.client_mock.indices.get_alias = MagicMock() - + # Start patches BEFORE creating ElasticsearchDB instance - patcher = patch('mem0.vector_stores.elasticsearch.Elasticsearch', return_value=self.client_mock) + patcher = patch("mem0.vector_stores.elasticsearch.Elasticsearch", return_value=self.client_mock) self.mock_es = patcher.start() self.addCleanup(patcher.stop) - + # Initialize ElasticsearchDB with test config and auto_create_index=False self.es_db = ElasticsearchDB( - host=os.getenv('ES_URL'), + host=os.getenv("ES_URL"), port=9200, collection_name="test_collection", embedding_model_dims=1536, - user=os.getenv('ES_USERNAME'), - password=os.getenv('ES_PASSWORD'), + user=os.getenv("ES_USERNAME"), + password=os.getenv("ES_PASSWORD"), verify_certs=False, use_ssl=False, - auto_create_index=False # Disable auto creation for tests + auto_create_index=False, # Disable auto creation for tests ) - + # Reset mock counts after initialization self.client_mock.reset_mock() @@ -80,15 +78,15 @@ class TestElasticsearchDB(unittest.TestCase): # Test when index doesn't exist self.client_mock.indices.exists.return_value = False self.es_db.create_index() - + # Verify index creation was called with correct settings self.client_mock.indices.create.assert_called_once() create_args = self.client_mock.indices.create.call_args[1] - + # Verify basic index settings self.assertEqual(create_args["index"], "test_collection") self.assertIn("mappings", create_args["body"]) - + # Verify field mappings mappings = create_args["body"]["mappings"]["properties"] self.assertEqual(mappings["text"]["type"], "text") @@ -97,53 +95,53 @@ class TestElasticsearchDB(unittest.TestCase): self.assertEqual(mappings["vector"]["index"], True) self.assertEqual(mappings["vector"]["similarity"], "cosine") self.assertEqual(mappings["metadata"]["type"], "object") - + # Reset mocks for next test self.client_mock.reset_mock() - + # Test when index already exists self.client_mock.indices.exists.return_value = True self.es_db.create_index() - + # Verify create was not called when index exists self.client_mock.indices.create.assert_not_called() def 
test_auto_create_index(self): # Reset mock self.client_mock.reset_mock() - + # Test with auto_create_index=True ElasticsearchDB( - host=os.getenv('ES_URL'), + host=os.getenv("ES_URL"), port=9200, collection_name="test_collection", embedding_model_dims=1536, - user=os.getenv('ES_USERNAME'), - password=os.getenv('ES_PASSWORD'), + user=os.getenv("ES_USERNAME"), + password=os.getenv("ES_PASSWORD"), verify_certs=False, use_ssl=False, - auto_create_index=True + auto_create_index=True, ) - + # Verify create_index was called during initialization self.client_mock.indices.exists.assert_called_once() - + # Reset mock self.client_mock.reset_mock() - + # Test with auto_create_index=False ElasticsearchDB( - host=os.getenv('ES_URL'), + host=os.getenv("ES_URL"), port=9200, collection_name="test_collection", embedding_model_dims=1536, - user=os.getenv('ES_USERNAME'), - password=os.getenv('ES_PASSWORD'), + user=os.getenv("ES_USERNAME"), + password=os.getenv("ES_PASSWORD"), verify_certs=False, use_ssl=False, - auto_create_index=False + auto_create_index=False, ) - + # Verify create_index was not called during initialization self.client_mock.indices.exists.assert_not_called() @@ -152,17 +150,17 @@ class TestElasticsearchDB(unittest.TestCase): vectors = [[0.1] * 1536, [0.2] * 1536] payloads = [{"key1": "value1"}, {"key2": "value2"}] ids = ["id1", "id2"] - + # Mock bulk operation - with patch('mem0.vector_stores.elasticsearch.bulk') as mock_bulk: + with patch("mem0.vector_stores.elasticsearch.bulk") as mock_bulk: mock_bulk.return_value = (2, []) # Simulate successful bulk insert - + # Perform insert results = self.es_db.insert(vectors=vectors, payloads=payloads, ids=ids) - + # Verify bulk was called mock_bulk.assert_called_once() - + # Verify bulk actions format actions = mock_bulk.call_args[0][1] self.assertEqual(len(actions), 2) @@ -170,7 +168,7 @@ class TestElasticsearchDB(unittest.TestCase): self.assertEqual(actions[0]["_id"], "id1") self.assertEqual(actions[0]["_source"]["vector"], vectors[0]) self.assertEqual(actions[0]["_source"]["metadata"], payloads[0]) - + # Verify returned objects self.assertEqual(len(results), 2) self.assertIsInstance(results[0], OutputData) @@ -182,14 +180,7 @@ class TestElasticsearchDB(unittest.TestCase): mock_response = { "hits": { "hits": [ - { - "_id": "id1", - "_score": 0.8, - "_source": { - "vector": [0.1] * 1536, - "metadata": {"key1": "value1"} - } - } + {"_id": "id1", "_score": 0.8, "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}} ] } } @@ -206,7 +197,7 @@ class TestElasticsearchDB(unittest.TestCase): # Verify search parameters self.assertEqual(search_args["index"], "test_collection") body = search_args["body"] - + # Verify KNN query structure self.assertIn("knn", body) self.assertEqual(body["knn"]["field"], "vector") @@ -235,29 +226,24 @@ class TestElasticsearchDB(unittest.TestCase): self.es_db.custom_search_query.assert_called_once_with(vectors, limit, filters) # Verify custom search query was used - self.client_mock.search.assert_called_once_with(index=self.es_db.collection_name, body={"custom_key": "custom_value"}) + self.client_mock.search.assert_called_once_with( + index=self.es_db.collection_name, body={"custom_key": "custom_value"} + ) def test_get(self): # Mock get response with correct structure mock_response = { "_id": "id1", - "_source": { - "vector": [0.1] * 1536, - "metadata": {"key": "value"}, - "text": "sample text" - } + "_source": {"vector": [0.1] * 1536, "metadata": {"key": "value"}, "text": "sample text"}, } 
self.client_mock.get.return_value = mock_response - + # Perform get result = self.es_db.get(vector_id="id1") - + # Verify get call - self.client_mock.get.assert_called_once_with( - index="test_collection", - id="id1" - ) - + self.client_mock.get.assert_called_once_with(index="test_collection", id="id1") + # Verify result self.assertIsNotNone(result) self.assertEqual(result.id, "id1") @@ -267,7 +253,7 @@ class TestElasticsearchDB(unittest.TestCase): def test_get_not_found(self): # Mock get raising exception self.client_mock.get.side_effect = Exception("Not found") - + # Verify get returns None when document not found result = self.es_db.get(vector_id="nonexistent") self.assertIsNone(result) @@ -277,33 +263,19 @@ class TestElasticsearchDB(unittest.TestCase): mock_response = { "hits": { "hits": [ - { - "_id": "id1", - "_source": { - "vector": [0.1] * 1536, - "metadata": {"key1": "value1"} - }, - "_score": 1.0 - }, - { - "_id": "id2", - "_source": { - "vector": [0.2] * 1536, - "metadata": {"key2": "value2"} - }, - "_score": 0.8 - } + {"_id": "id1", "_source": {"vector": [0.1] * 1536, "metadata": {"key1": "value1"}}, "_score": 1.0}, + {"_id": "id2", "_source": {"vector": [0.2] * 1536, "metadata": {"key2": "value2"}}, "_score": 0.8}, ] } } self.client_mock.search.return_value = mock_response - + # Perform list operation results = self.es_db.list(limit=10) - + # Verify search call self.client_mock.search.assert_called_once() - + # Verify results self.assertEqual(len(results), 1) # Outer list self.assertEqual(len(results[0]), 2) # Inner list @@ -316,30 +288,24 @@ class TestElasticsearchDB(unittest.TestCase): def test_delete(self): # Perform delete self.es_db.delete(vector_id="id1") - + # Verify delete call - self.client_mock.delete.assert_called_once_with( - index="test_collection", - id="id1" - ) + self.client_mock.delete.assert_called_once_with(index="test_collection", id="id1") def test_list_cols(self): # Mock indices response mock_indices = {"index1": {}, "index2": {}} self.client_mock.indices.get_alias.return_value = mock_indices - + # Get collections result = self.es_db.list_cols() - + # Verify result self.assertEqual(result, ["index1", "index2"]) def test_delete_col(self): # Delete collection self.es_db.delete_col() - + # Verify delete call - self.client_mock.indices.delete.assert_called_once_with( - index="test_collection" - ) - \ No newline at end of file + self.client_mock.indices.delete.assert_called_once_with(index="test_collection") diff --git a/tests/vector_stores/test_faiss.py b/tests/vector_stores/test_faiss.py index b44bfa1f..07652d14 100644 --- a/tests/vector_stores/test_faiss.py +++ b/tests/vector_stores/test_faiss.py @@ -21,9 +21,9 @@ def mock_faiss_index(): def faiss_instance(mock_faiss_index): with tempfile.TemporaryDirectory() as temp_dir: # Mock the faiss index creation - with patch('faiss.IndexFlatL2', return_value=mock_faiss_index): + with patch("faiss.IndexFlatL2", return_value=mock_faiss_index): # Mock the faiss.write_index function - with patch('faiss.write_index'): + with patch("faiss.write_index"): # Create a FAISS instance with a temporary directory faiss_store = FAISS( collection_name="test_collection", @@ -37,14 +37,14 @@ def faiss_instance(mock_faiss_index): def test_create_col(faiss_instance, mock_faiss_index): # Test creating a collection with euclidean distance - with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2: - with patch('faiss.write_index'): + with patch("faiss.IndexFlatL2", return_value=mock_faiss_index) as 
mock_index_flat_l2: + with patch("faiss.write_index"): faiss_instance.create_col(name="new_collection") mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims) - + # Test creating a collection with inner product distance - with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip: - with patch('faiss.write_index'): + with patch("faiss.IndexFlatIP", return_value=mock_faiss_index) as mock_index_flat_ip: + with patch("faiss.write_index"): faiss_instance.create_col(name="new_collection", distance="inner_product") mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims) @@ -54,21 +54,21 @@ def test_insert(faiss_instance, mock_faiss_index): vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] payloads = [{"name": "vector1"}, {"name": "vector2"}] ids = ["id1", "id2"] - + # Mock the numpy array conversion - with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array: + with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)) as mock_np_array: # Mock index.add mock_faiss_index.add.return_value = None - + # Call insert faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids) - + # Verify numpy.array was called mock_np_array.assert_called_once_with(vectors, dtype=np.float32) - + # Verify index.add was called mock_faiss_index.add.assert_called_once() - + # Verify docstore and index_to_id were updated assert faiss_instance.docstore["id1"] == {"name": "vector1"} assert faiss_instance.docstore["id2"] == {"name": "vector2"} @@ -79,39 +79,36 @@ def test_insert(faiss_instance, mock_faiss_index): def test_search(faiss_instance, mock_faiss_index): # Prepare test data query_vector = [0.1, 0.2, 0.3] - + # Setup the docstore and index_to_id mapping - faiss_instance.docstore = { - "id1": {"name": "vector1"}, - "id2": {"name": "vector2"} - } + faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}} faiss_instance.index_to_id = {0: "id1", 1: "id2"} - + # First, create the mock for the search return values search_scores = np.array([[0.9, 0.8]]) search_indices = np.array([[0, 1]]) mock_faiss_index.search.return_value = (search_scores, search_indices) - + # Then patch numpy.array only for the query vector conversion - with patch('numpy.array') as mock_np_array: + with patch("numpy.array") as mock_np_array: mock_np_array.return_value = np.array(query_vector, dtype=np.float32) - + # Then patch _parse_output to return the expected results expected_results = [ OutputData(id="id1", score=0.9, payload={"name": "vector1"}), - OutputData(id="id2", score=0.8, payload={"name": "vector2"}) + OutputData(id="id2", score=0.8, payload={"name": "vector2"}), ] - - with patch.object(faiss_instance, '_parse_output', return_value=expected_results): + + with patch.object(faiss_instance, "_parse_output", return_value=expected_results): # Call search results = faiss_instance.search(query="test query", vectors=query_vector, limit=2) - + # Verify numpy.array was called (but we don't check exact call arguments since it's complex) assert mock_np_array.called - + # Verify index.search was called mock_faiss_index.search.assert_called_once() - + # Verify results assert len(results) == 2 assert results[0].id == "id1" @@ -125,47 +122,41 @@ def test_search(faiss_instance, mock_faiss_index): def test_search_with_filters(faiss_instance, mock_faiss_index): # Prepare test data query_vector = [0.1, 0.2, 0.3] - + # Setup the docstore and index_to_id mapping - faiss_instance.docstore = { - "id1": {"name": 
"vector1", "category": "A"}, - "id2": {"name": "vector2", "category": "B"} - } + faiss_instance.docstore = {"id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}} faiss_instance.index_to_id = {0: "id1", 1: "id2"} - + # First set up the search return values search_scores = np.array([[0.9, 0.8]]) search_indices = np.array([[0, 1]]) mock_faiss_index.search.return_value = (search_scores, search_indices) - + # Patch numpy.array for query vector conversion - with patch('numpy.array') as mock_np_array: + with patch("numpy.array") as mock_np_array: mock_np_array.return_value = np.array(query_vector, dtype=np.float32) - + # Directly mock the _parse_output method to return our expected values # We're simulating that _parse_output filters to just the first result all_results = [ OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}), - OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}) + OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}), ] # Replace the _apply_filters method to handle our test case - with patch.object(faiss_instance, '_parse_output', return_value=all_results): - with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"): + with patch.object(faiss_instance, "_parse_output", return_value=all_results): + with patch.object(faiss_instance, "_apply_filters", side_effect=lambda p, f: p.get("category") == "A"): # Call search with filters results = faiss_instance.search( - query="test query", - vectors=query_vector, - limit=2, - filters={"category": "A"} + query="test query", vectors=query_vector, limit=2, filters={"category": "A"} ) - + # Verify numpy.array was called assert mock_np_array.called - + # Verify index.search was called mock_faiss_index.search.assert_called_once() - + # Verify filtered results - since we've mocked everything, # we should get just the result we want assert len(results) == 1 @@ -176,15 +167,12 @@ def test_search_with_filters(faiss_instance, mock_faiss_index): def test_delete(faiss_instance): # Setup the docstore and index_to_id mapping - faiss_instance.docstore = { - "id1": {"name": "vector1"}, - "id2": {"name": "vector2"} - } + faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}} faiss_instance.index_to_id = {0: "id1", 1: "id2"} - + # Call delete faiss_instance.delete(vector_id="id1") - + # Verify the vector was removed from docstore and index_to_id assert "id1" not in faiss_instance.docstore assert 0 not in faiss_instance.index_to_id @@ -194,23 +182,20 @@ def test_delete(faiss_instance): def test_update(faiss_instance, mock_faiss_index): # Setup the docstore and index_to_id mapping - faiss_instance.docstore = { - "id1": {"name": "vector1"}, - "id2": {"name": "vector2"} - } + faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}} faiss_instance.index_to_id = {0: "id1", 1: "id2"} - + # Test updating payload only faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"}) assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"} - + # Test updating vector # This requires mocking the delete and insert methods - with patch.object(faiss_instance, 'delete') as mock_delete: - with patch.object(faiss_instance, 'insert') as mock_insert: + with patch.object(faiss_instance, "delete") as mock_delete: + with patch.object(faiss_instance, "insert") as mock_insert: new_vector = [0.7, 0.8, 0.9] faiss_instance.update(vector_id="id2", 
vector=new_vector) - + # Verify delete and insert were called # Match the actual call signature (positional arg instead of keyword) mock_delete.assert_called_once_with("id2") @@ -219,17 +204,14 @@ def test_update(faiss_instance, mock_faiss_index): def test_get(faiss_instance): # Setup the docstore - faiss_instance.docstore = { - "id1": {"name": "vector1"}, - "id2": {"name": "vector2"} - } - + faiss_instance.docstore = {"id1": {"name": "vector1"}, "id2": {"name": "vector2"}} + # Test getting an existing vector result = faiss_instance.get(vector_id="id1") assert result.id == "id1" assert result.payload == {"name": "vector1"} assert result.score is None - + # Test getting a non-existent vector result = faiss_instance.get(vector_id="id3") assert result is None @@ -240,18 +222,18 @@ def test_list(faiss_instance): faiss_instance.docstore = { "id1": {"name": "vector1", "category": "A"}, "id2": {"name": "vector2", "category": "B"}, - "id3": {"name": "vector3", "category": "A"} + "id3": {"name": "vector3", "category": "A"}, } - + # Test listing all vectors results = faiss_instance.list() # Fix the expected result - the list method returns a list of lists assert len(results[0]) == 3 - + # Test listing with a limit results = faiss_instance.list(limit=2) assert len(results[0]) == 2 - + # Test listing with filters results = faiss_instance.list(filters={"category": "A"}) assert len(results[0]) == 2 @@ -263,10 +245,10 @@ def test_col_info(faiss_instance, mock_faiss_index): # Mock index attributes mock_faiss_index.ntotal = 5 mock_faiss_index.d = 128 - + # Get collection info info = faiss_instance.col_info() - + # Verify the returned info assert info["name"] == "test_collection" assert info["count"] == 5 @@ -276,14 +258,14 @@ def test_col_info(faiss_instance, mock_faiss_index): def test_delete_col(faiss_instance): # Mock the os.remove function - with patch('os.remove') as mock_remove: - with patch('os.path.exists', return_value=True): + with patch("os.remove") as mock_remove: + with patch("os.path.exists", return_value=True): # Call delete_col faiss_instance.delete_col() - + # Verify os.remove was called twice (for index and docstore files) assert mock_remove.call_count == 2 - + # Verify the internal state was reset assert faiss_instance.index is None assert faiss_instance.docstore == {} @@ -293,17 +275,17 @@ def test_delete_col(faiss_instance): def test_normalize_L2(faiss_instance, mock_faiss_index): # Setup a FAISS instance with normalize_L2=True faiss_instance.normalize_L2 = True - + # Prepare test data vectors = [[0.1, 0.2, 0.3]] - + # Mock numpy array conversion # Mock numpy array conversion - with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)): + with patch("numpy.array", return_value=np.array(vectors, dtype=np.float32)): # Mock faiss.normalize_L2 - with patch('faiss.normalize_L2') as mock_normalize: + with patch("faiss.normalize_L2") as mock_normalize: # Call insert faiss_instance.insert(vectors=vectors, ids=["id1"]) - + # Verify faiss.normalize_L2 was called mock_normalize.assert_called_once() diff --git a/tests/vector_stores/test_langchain_vector_store.py b/tests/vector_stores/test_langchain_vector_store.py index 9c9e6ca2..6e156ec2 100644 --- a/tests/vector_stores/test_langchain_vector_store.py +++ b/tests/vector_stores/test_langchain_vector_store.py @@ -11,11 +11,13 @@ def mock_langchain_client(): with patch("langchain_community.vectorstores.VectorStore") as mock_client: yield mock_client + @pytest.fixture def langchain_instance(mock_langchain_client): mock_client = 
Mock(spec=VectorStore) return Langchain(client=mock_client, collection_name="test_collection") + def test_insert_vectors(langchain_instance): # Test data vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] @@ -25,48 +27,31 @@ def test_insert_vectors(langchain_instance): # Test with add_embeddings method langchain_instance.client.add_embeddings = Mock() langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids) - langchain_instance.client.add_embeddings.assert_called_once_with( - embeddings=vectors, - metadatas=payloads, - ids=ids - ) + langchain_instance.client.add_embeddings.assert_called_once_with(embeddings=vectors, metadatas=payloads, ids=ids) # Test with add_texts method delattr(langchain_instance.client, "add_embeddings") # Remove attribute completely langchain_instance.client.add_texts = Mock() langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids) - langchain_instance.client.add_texts.assert_called_once_with( - texts=["text1", "text2"], - metadatas=payloads, - ids=ids - ) + langchain_instance.client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=payloads, ids=ids) # Test with empty payloads langchain_instance.client.add_texts.reset_mock() langchain_instance.insert(vectors=vectors, payloads=None, ids=ids) - langchain_instance.client.add_texts.assert_called_once_with( - texts=["", ""], - metadatas=None, - ids=ids - ) + langchain_instance.client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=ids) + def test_search_vectors(langchain_instance): # Mock search results - mock_docs = [ - Mock(metadata={"name": "vector1"}, id="id1"), - Mock(metadata={"name": "vector2"}, id="id2") - ] + mock_docs = [Mock(metadata={"name": "vector1"}, id="id1"), Mock(metadata={"name": "vector2"}, id="id2")] langchain_instance.client.similarity_search_by_vector.return_value = mock_docs # Test search without filters vectors = [[0.1, 0.2, 0.3]] results = langchain_instance.search(query="", vectors=vectors, limit=2) - - langchain_instance.client.similarity_search_by_vector.assert_called_once_with( - embedding=vectors, - k=2 - ) - + + langchain_instance.client.similarity_search_by_vector.assert_called_once_with(embedding=vectors, k=2) + assert len(results) == 2 assert results[0].id == "id1" assert results[0].payload == {"name": "vector1"} @@ -76,11 +61,8 @@ def test_search_vectors(langchain_instance): # Test search with filters filters = {"name": "vector1"} langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters) - langchain_instance.client.similarity_search_by_vector.assert_called_with( - embedding=vectors, - k=2, - filter=filters - ) + langchain_instance.client.similarity_search_by_vector.assert_called_with(embedding=vectors, k=2, filter=filters) + def test_get_vector(langchain_instance): # Mock get result @@ -90,7 +72,7 @@ def test_get_vector(langchain_instance): # Test get existing vector result = langchain_instance.get("id1") langchain_instance.client.get_by_ids.assert_called_once_with(["id1"]) - + assert result is not None assert result.id == "id1" assert result.payload == {"name": "vector1"} diff --git a/tests/vector_stores/test_opensearch.py b/tests/vector_stores/test_opensearch.py index df155a7d..043c9efe 100644 --- a/tests/vector_stores/test_opensearch.py +++ b/tests/vector_stores/test_opensearch.py @@ -8,9 +8,7 @@ import pytest try: from opensearchpy import AWSV4SignerAuth, OpenSearch except ImportError: - raise ImportError( - "OpenSearch requires extra dependencies. 
Install with `pip install opensearch-py`" - ) from None + raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None from mem0.vector_stores.opensearch import OpenSearchDB @@ -20,13 +18,13 @@ class TestOpenSearchDB(unittest.TestCase): def setUpClass(cls): dotenv.load_dotenv() cls.original_env = { - 'OS_URL': os.getenv('OS_URL', 'http://localhost:9200'), - 'OS_USERNAME': os.getenv('OS_USERNAME', 'test_user'), - 'OS_PASSWORD': os.getenv('OS_PASSWORD', 'test_password') + "OS_URL": os.getenv("OS_URL", "http://localhost:9200"), + "OS_USERNAME": os.getenv("OS_USERNAME", "test_user"), + "OS_PASSWORD": os.getenv("OS_PASSWORD", "test_password"), } - os.environ['OS_URL'] = 'http://localhost' - os.environ['OS_USERNAME'] = 'test_user' - os.environ['OS_PASSWORD'] = 'test_password' + os.environ["OS_URL"] = "http://localhost" + os.environ["OS_USERNAME"] = "test_user" + os.environ["OS_PASSWORD"] = "test_password" def setUp(self): self.client_mock = MagicMock(spec=OpenSearch) @@ -40,19 +38,19 @@ class TestOpenSearchDB(unittest.TestCase): self.client_mock.delete = MagicMock() self.client_mock.search = MagicMock() - patcher = patch('mem0.vector_stores.opensearch.OpenSearch', return_value=self.client_mock) + patcher = patch("mem0.vector_stores.opensearch.OpenSearch", return_value=self.client_mock) self.mock_os = patcher.start() self.addCleanup(patcher.stop) self.os_db = OpenSearchDB( - host=os.getenv('OS_URL'), + host=os.getenv("OS_URL"), port=9200, collection_name="test_collection", embedding_model_dims=1536, - user=os.getenv('OS_USERNAME'), - password=os.getenv('OS_PASSWORD'), + user=os.getenv("OS_USERNAME"), + password=os.getenv("OS_PASSWORD"), verify_certs=False, - use_ssl=False + use_ssl=False, ) self.client_mock.reset_mock() @@ -86,29 +84,29 @@ class TestOpenSearchDB(unittest.TestCase): vectors = [[0.1] * 1536, [0.2] * 1536] payloads = [{"key1": "value1"}, {"key2": "value2"}] ids = ["id1", "id2"] - + # Mock the index method self.client_mock.index = MagicMock() - + results = self.os_db.insert(vectors=vectors, payloads=payloads, ids=ids) - + # Verify index was called twice (once for each vector) self.assertEqual(self.client_mock.index.call_count, 2) - + # Check first call first_call = self.client_mock.index.call_args_list[0] self.assertEqual(first_call[1]["index"], "test_collection") self.assertEqual(first_call[1]["body"]["vector_field"], vectors[0]) self.assertEqual(first_call[1]["body"]["payload"], payloads[0]) self.assertEqual(first_call[1]["body"]["id"], ids[0]) - + # Check second call second_call = self.client_mock.index.call_args_list[1] self.assertEqual(second_call[1]["index"], "test_collection") self.assertEqual(second_call[1]["body"]["vector_field"], vectors[1]) self.assertEqual(second_call[1]["body"]["payload"], payloads[1]) self.assertEqual(second_call[1]["body"]["id"], ids[1]) - + # Check results self.assertEqual(len(results), 2) self.assertEqual(results[0].id, "id1") @@ -132,7 +130,7 @@ class TestOpenSearchDB(unittest.TestCase): self.client_mock.search.return_value = {"hits": {"hits": []}} result = self.os_db.get("nonexistent") self.assertIsNone(result) - + def test_update(self): vector = [0.3] * 1536 payload = {"key3": "value3"} @@ -152,7 +150,17 @@ class TestOpenSearchDB(unittest.TestCase): self.assertEqual(result, ["test_collection"]) def test_search(self): - mock_response = {"hits": {"hits": [{"_id": "id1", "_score": 0.8, "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}}]}} + mock_response = { + 
"hits": { + "hits": [ + { + "_id": "id1", + "_score": 0.8, + "_source": {"vector_field": [0.1] * 1536, "id": "id1", "payload": {"key1": "value1"}}, + } + ] + } + } self.client_mock.search.return_value = mock_response vectors = [[0.1] * 1536] results = self.os_db.search(query="", vectors=vectors, limit=5) @@ -179,12 +187,11 @@ class TestOpenSearchDB(unittest.TestCase): self.os_db.delete_col() self.client_mock.indices.delete.assert_called_once_with(index="test_collection") - def test_init_with_http_auth(self): mock_credentials = MagicMock() mock_signer = AWSV4SignerAuth(mock_credentials, "us-east-1", "es") - with patch('mem0.vector_stores.opensearch.OpenSearch') as mock_opensearch: + with patch("mem0.vector_stores.opensearch.OpenSearch") as mock_opensearch: OpenSearchDB( host="localhost", port=9200, @@ -192,7 +199,7 @@ class TestOpenSearchDB(unittest.TestCase): embedding_model_dims=1536, http_auth=mock_signer, verify_certs=True, - use_ssl=True + use_ssl=True, ) # Verify OpenSearch was initialized with correct params @@ -202,5 +209,5 @@ class TestOpenSearchDB(unittest.TestCase): use_ssl=True, verify_certs=True, connection_class=unittest.mock.ANY, - pool_maxsize=20 - ) \ No newline at end of file + pool_maxsize=20, + ) diff --git a/tests/vector_stores/test_pinecone.py b/tests/vector_stores/test_pinecone.py index 2ff5410d..2cfaf996 100644 --- a/tests/vector_stores/test_pinecone.py +++ b/tests/vector_stores/test_pinecone.py @@ -12,6 +12,7 @@ def mock_pinecone_client(): client.list_indexes.return_value.names.return_value = [] return client + @pytest.fixture def pinecone_db(mock_pinecone_client): return PineconeDB( @@ -25,13 +26,14 @@ def pinecone_db(mock_pinecone_client): hybrid_search=False, metric="cosine", batch_size=100, - extra_params=None + extra_params=None, ) + def test_create_col_existing_index(mock_pinecone_client): # Set up the mock before creating the PineconeDB object mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"] - + pinecone_db = PineconeDB( collection_name="test_index", embedding_model_dims=128, @@ -43,21 +45,23 @@ def test_create_col_existing_index(mock_pinecone_client): hybrid_search=False, metric="cosine", batch_size=100, - extra_params=None + extra_params=None, ) - + # Reset the mock to verify it wasn't called during the test mock_pinecone_client.create_index.reset_mock() - + pinecone_db.create_col(128, "cosine") - + mock_pinecone_client.create_index.assert_not_called() + def test_create_col_new_index(pinecone_db, mock_pinecone_client): mock_pinecone_client.list_indexes.return_value.names.return_value = [] pinecone_db.create_col(128, "cosine") mock_pinecone_client.create_index.assert_called() + def test_insert_vectors(pinecone_db): vectors = [[0.1] * 128, [0.2] * 128] payloads = [{"name": "vector1"}, {"name": "vector2"}] @@ -65,56 +69,61 @@ def test_insert_vectors(pinecone_db): pinecone_db.insert(vectors, payloads, ids) pinecone_db.index.upsert.assert_called() + def test_search_vectors(pinecone_db): pinecone_db.index.query.return_value.matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}] - results = pinecone_db.search("test query",[0.1] * 128, limit=1) + results = pinecone_db.search("test query", [0.1] * 128, limit=1) assert len(results) == 1 assert results[0].id == "id1" assert results[0].score == 0.9 + def test_update_vector(pinecone_db): pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"}) pinecone_db.index.upsert.assert_called() + def test_get_vector_found(pinecone_db): # Looking at the 
_parse_output method, it expects a Vector object # or a list of dictionaries, not a dictionary with an 'id' field - + # Create a mock Vector object from pinecone.data.dataclasses.vector import Vector - mock_vector = Vector( - id="id1", - values=[0.1] * 128, - metadata={"name": "vector1"} - ) - + + mock_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"}) + # Mock the fetch method to return the mock response object mock_response = MagicMock() mock_response.vectors = {"id1": mock_vector} pinecone_db.index.fetch.return_value = mock_response - + result = pinecone_db.get("id1") assert result is not None assert result.id == "id1" assert result.payload == {"name": "vector1"} + def test_delete_vector(pinecone_db): pinecone_db.delete("id1") pinecone_db.index.delete.assert_called_with(ids=["id1"]) + def test_get_vector_not_found(pinecone_db): pinecone_db.index.fetch.return_value.vectors = {} result = pinecone_db.get("id1") assert result is None + def test_list_cols(pinecone_db): pinecone_db.list_cols() pinecone_db.client.list_indexes.assert_called() + def test_delete_col(pinecone_db): pinecone_db.delete_col() pinecone_db.client.delete_index.assert_called_with("test_index") + def test_col_info(pinecone_db): pinecone_db.col_info() pinecone_db.client.describe_index.assert_called_with("test_index") diff --git a/tests/vector_stores/test_supabase.py b/tests/vector_stores/test_supabase.py index f94c203e..e051ccf1 100644 --- a/tests/vector_stores/test_supabase.py +++ b/tests/vector_stores/test_supabase.py @@ -37,7 +37,7 @@ def supabase_instance(mock_vecs_client, mock_collection): index_method=IndexMethod.HNSW, index_measure=IndexMeasure.COSINE, ) - + # Manually set the collection attribute since we're mocking the initialization instance.collection = mock_collection return instance @@ -46,14 +46,8 @@ def supabase_instance(mock_vecs_client, mock_collection): def test_create_col(supabase_instance, mock_vecs_client, mock_collection): supabase_instance.create_col(1536) - mock_vecs_client.return_value.get_or_create_collection.assert_called_with( - name="test_collection", - dimension=1536 - ) - mock_collection.create_index.assert_called_with( - method="hnsw", - measure="cosine_distance" - ) + mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536) + mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance") def test_insert_vectors(supabase_instance, mock_collection): @@ -63,18 +57,12 @@ def test_insert_vectors(supabase_instance, mock_collection): supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids) - expected_records = [ - ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), - ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}) - ] + expected_records = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})] mock_collection.upsert.assert_called_once_with(expected_records) def test_search_vectors(supabase_instance, mock_collection): - mock_results = [ - ("id1", 0.9, {"name": "vector1"}), - ("id2", 0.8, {"name": "vector2"}) - ] + mock_results = [("id1", 0.9, {"name": "vector1"}), ("id2", 0.8, {"name": "vector2"})] mock_collection.query.return_value = mock_results vectors = [[0.1, 0.2, 0.3]] @@ -82,11 +70,7 @@ def test_search_vectors(supabase_instance, mock_collection): results = supabase_instance.search(query="", vectors=vectors, limit=2, filters=filters) mock_collection.query.assert_called_once_with( - data=vectors, - limit=2, - filters={"category": {"$eq": 
"test"}}, - include_metadata=True, - include_value=True + data=vectors, limit=2, filters={"category": {"$eq": "test"}}, include_metadata=True, include_value=True ) assert len(results) == 2 @@ -129,11 +113,8 @@ def test_get_vector(supabase_instance, mock_collection): def test_list_vectors(supabase_instance, mock_collection): mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})] - mock_fetch_results = [ - ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), - ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}) - ] - + mock_fetch_results = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"}), ("id2", [0.4, 0.5, 0.6], {"name": "vector2"})] + mock_collection.query.return_value = mock_query_results mock_collection.fetch.return_value = mock_fetch_results @@ -153,10 +134,7 @@ def test_col_info(supabase_instance, mock_collection): "name": "test_collection", "count": 100, "dimension": 1536, - "index": { - "method": "hnsw", - "metric": "cosine_distance" - } + "index": {"method": "hnsw", "metric": "cosine_distance"}, } @@ -168,10 +146,7 @@ def test_preprocess_filters(supabase_instance): # Test multiple filters multi_filter = {"category": "test", "type": "document"} assert supabase_instance._preprocess_filters(multi_filter) == { - "$and": [ - {"category": {"$eq": "test"}}, - {"type": {"$eq": "document"}} - ] + "$and": [{"category": {"$eq": "test"}}, {"type": {"$eq": "document"}}] } # Test None filters diff --git a/tests/vector_stores/test_upstash_vector.py b/tests/vector_stores/test_upstash_vector.py index e5a38846..ad342341 100644 --- a/tests/vector_stores/test_upstash_vector.py +++ b/tests/vector_stores/test_upstash_vector.py @@ -29,9 +29,7 @@ def upstash_instance(mock_index): @pytest.fixture def upstash_instance_with_embeddings(mock_index): - return UpstashVector( - client=mock_index.return_value, collection_name="ns", enable_embeddings=True - ) + return UpstashVector(client=mock_index.return_value, collection_name="ns", enable_embeddings=True) def test_insert_vectors(upstash_instance, mock_index): @@ -52,12 +50,8 @@ def test_insert_vectors(upstash_instance, mock_index): def test_search_vectors(upstash_instance, mock_index): mock_result = [ - QueryResult( - id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None - ), - QueryResult( - id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None - ), + QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None), + QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None), ] upstash_instance.client.query_many.return_value = [mock_result] @@ -93,9 +87,7 @@ def test_delete_vector(upstash_instance): upstash_instance.delete(vector_id=vector_id) - upstash_instance.client.delete.assert_called_once_with( - ids=[vector_id], namespace="ns" - ) + upstash_instance.client.delete.assert_called_once_with(ids=[vector_id], namespace="ns") def test_update_vector(upstash_instance): @@ -115,18 +107,12 @@ def test_update_vector(upstash_instance): def test_get_vector(upstash_instance): - mock_result = [ - QueryResult( - id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None - ) - ] + mock_result = [QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)] upstash_instance.client.fetch.return_value = mock_result result = upstash_instance.get(vector_id="id1") - upstash_instance.client.fetch.assert_called_once_with( - ids=["id1"], namespace="ns", include_metadata=True - ) + upstash_instance.client.fetch.assert_called_once_with(ids=["id1"], 
namespace="ns", include_metadata=True) assert result.id == "id1" assert result.payload == {"name": "vector1"} @@ -134,15 +120,9 @@ def test_get_vector(upstash_instance): def test_list_vectors(upstash_instance): mock_result = [ - QueryResult( - id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None - ), - QueryResult( - id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None - ), - QueryResult( - id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None - ), + QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None), + QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None), + QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None), ] handler = MagicMock() @@ -204,12 +184,8 @@ def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_i def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index): mock_result = [ - QueryResult( - id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1" - ), - QueryResult( - id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2" - ), + QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"), + QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"), ] upstash_instance_with_embeddings.client.query.return_value = mock_result @@ -260,9 +236,7 @@ def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embed ValueError, match="When embeddings are enabled, all payloads must contain a 'data' field", ): - upstash_instance_with_embeddings.insert( - vectors=vectors, payloads=payloads, ids=ids - ) + upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids) def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings): @@ -316,18 +290,12 @@ def test_get_vector_not_found(upstash_instance): result = upstash_instance.get(vector_id="nonexistent") - upstash_instance.client.fetch.assert_called_once_with( - ids=["nonexistent"], namespace="ns", include_metadata=True - ) + upstash_instance.client.fetch.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True) assert result is None def test_search_vectors_empty_filters(upstash_instance): - mock_result = [ - QueryResult( - id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None - ) - ] + mock_result = [QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)] upstash_instance.client.query_many.return_value = [mock_result] vectors = [[0.1, 0.2, 0.3]] diff --git a/tests/vector_stores/test_vertex_ai_vector_search.py b/tests/vector_stores/test_vertex_ai_vector_search.py index 3a1ab50d..d0d1f4c9 100644 --- a/tests/vector_stores/test_vertex_ai_vector_search.py +++ b/tests/vector_stores/test_vertex_ai_vector_search.py @@ -14,47 +14,50 @@ from mem0.vector_stores.vertex_ai_vector_search import GoogleMatchingEngine @pytest.fixture def mock_vertex_ai(): - with patch('google.cloud.aiplatform.MatchingEngineIndex') as mock_index, \ - patch('google.cloud.aiplatform.MatchingEngineIndexEndpoint') as mock_endpoint, \ - patch('google.cloud.aiplatform.init') as mock_init: + with ( + patch("google.cloud.aiplatform.MatchingEngineIndex") as mock_index, + patch("google.cloud.aiplatform.MatchingEngineIndexEndpoint") as mock_endpoint, + patch("google.cloud.aiplatform.init") as mock_init, + ): mock_index_instance = 
Mock() mock_endpoint_instance = Mock() yield { - 'index': mock_index_instance, - 'endpoint': mock_endpoint_instance, - 'init': mock_init, - 'index_class': mock_index, - 'endpoint_class': mock_endpoint + "index": mock_index_instance, + "endpoint": mock_endpoint_instance, + "init": mock_init, + "index_class": mock_index, + "endpoint_class": mock_endpoint, } + @pytest.fixture def config(): return GoogleMatchingEngineConfig( - project_id='test-project', - project_number='123456789', - region='us-central1', - endpoint_id='test-endpoint', - index_id='test-index', - deployment_index_id='test-deployment', - collection_name='test-collection', - vector_search_api_endpoint='test.vertexai.goog' + project_id="test-project", + project_number="123456789", + region="us-central1", + endpoint_id="test-endpoint", + index_id="test-index", + deployment_index_id="test-deployment", + collection_name="test-collection", + vector_search_api_endpoint="test.vertexai.goog", ) + @pytest.fixture def vector_store(config, mock_vertex_ai): - mock_vertex_ai['index_class'].return_value = mock_vertex_ai['index'] - mock_vertex_ai['endpoint_class'].return_value = mock_vertex_ai['endpoint'] + mock_vertex_ai["index_class"].return_value = mock_vertex_ai["index"] + mock_vertex_ai["endpoint_class"].return_value = mock_vertex_ai["endpoint"] return GoogleMatchingEngine(**config.model_dump()) + def test_initialization(vector_store, mock_vertex_ai, config): """Test proper initialization of GoogleMatchingEngine""" - mock_vertex_ai['init'].assert_called_once_with( - project=config.project_id, - location=config.region - ) + mock_vertex_ai["init"].assert_called_once_with(project=config.project_id, location=config.region) expected_index_path = f"projects/{config.project_number}/locations/{config.region}/indexes/{config.index_id}" - mock_vertex_ai['index_class'].assert_called_once_with(index_name=expected_index_path) + mock_vertex_ai["index_class"].assert_called_once_with(index_name=expected_index_path) + def test_insert_vectors(vector_store, mock_vertex_ai): """Test inserting vectors with payloads""" @@ -64,13 +67,14 @@ def test_insert_vectors(vector_store, mock_vertex_ai): vector_store.insert(vectors=vectors, payloads=payloads, ids=ids) - mock_vertex_ai['index'].upsert_datapoints.assert_called_once() - call_args = mock_vertex_ai['index'].upsert_datapoints.call_args[1] - assert len(call_args['datapoints']) == 1 - datapoint_str = str(call_args['datapoints'][0]) + mock_vertex_ai["index"].upsert_datapoints.assert_called_once() + call_args = mock_vertex_ai["index"].upsert_datapoints.call_args[1] + assert len(call_args["datapoints"]) == 1 + datapoint_str = str(call_args["datapoints"][0]) assert "test-id" in datapoint_str assert "0.1" in datapoint_str and "0.2" in datapoint_str and "0.3" in datapoint_str + def test_search_vectors(vector_store, mock_vertex_ai): """Test searching vectors with filters""" vectors = [[0.1, 0.2, 0.3]] @@ -85,7 +89,7 @@ def test_search_vectors(vector_store, mock_vertex_ai): mock_restrict.allow_list = ["test_user"] mock_restrict.name = "user_id" mock_restrict.allow_tokens = ["test_user"] - + mock_datapoint.restricts = [mock_restrict] mock_neighbor = Mock() @@ -94,16 +98,16 @@ def test_search_vectors(vector_store, mock_vertex_ai): mock_neighbor.datapoint = mock_datapoint mock_neighbor.restricts = [mock_restrict] - mock_vertex_ai['endpoint'].find_neighbors.return_value = [[mock_neighbor]] + mock_vertex_ai["endpoint"].find_neighbors.return_value = [[mock_neighbor]] results = vector_store.search(query="", vectors=vectors, 
filters=filters, limit=1) - mock_vertex_ai['endpoint'].find_neighbors.assert_called_once_with( + mock_vertex_ai["endpoint"].find_neighbors.assert_called_once_with( deployed_index_id=vector_store.deployment_index_id, queries=[vectors], num_neighbors=1, filter=[Namespace("user_id", ["test_user"], [])], - return_full_datapoint=True + return_full_datapoint=True, ) assert len(results) == 1 @@ -111,29 +115,27 @@ def test_search_vectors(vector_store, mock_vertex_ai): assert results[0].score == 0.1 assert results[0].payload == {"user_id": "test_user"} + def test_delete(vector_store, mock_vertex_ai): """Test deleting vectors""" vector_id = "test-id" remove_mock = Mock() - with patch.object(GoogleMatchingEngine, 'delete', wraps=vector_store.delete) as delete_spy: - with patch.object(vector_store.index, 'remove_datapoints', remove_mock): + with patch.object(GoogleMatchingEngine, "delete", wraps=vector_store.delete) as delete_spy: + with patch.object(vector_store.index, "remove_datapoints", remove_mock): vector_store.delete(ids=[vector_id]) delete_spy.assert_called_once_with(ids=[vector_id]) remove_mock.assert_called_once_with(datapoint_ids=[vector_id]) + def test_error_handling(vector_store, mock_vertex_ai): """Test error handling during operations""" - mock_vertex_ai['index'].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request") + mock_vertex_ai["index"].upsert_datapoints.side_effect = exceptions.InvalidArgument("Invalid request") with pytest.raises(Exception) as exc_info: - vector_store.insert( - vectors=[[0.1, 0.2, 0.3]], - payloads=[{"name": "test"}], - ids=["test-id"] - ) + vector_store.insert(vectors=[[0.1, 0.2, 0.3]], payloads=[{"name": "test"}], ids=["test-id"]) assert isinstance(exc_info.value, exceptions.InvalidArgument) assert "Invalid request" in str(exc_info.value) diff --git a/tests/vector_stores/test_weaviate.py b/tests/vector_stores/test_weaviate.py index b05d9bdb..96776e43 100644 --- a/tests/vector_stores/test_weaviate.py +++ b/tests/vector_stores/test_weaviate.py @@ -76,15 +76,15 @@ # self.client_mock.batch = MagicMock() # self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock() - + # self.client_mock.collections.get.return_value.data.insert_many.return_value = { # "results": [{"id": "id1"}, {"id": "id2"}] # } - + # vectors = [[0.1] * 1536, [0.2] * 1536] # payloads = [{"key1": "value1"}, {"key2": "value2"}] # ids = [str(uuid.uuid4()), str(uuid.uuid4())] - + # results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids) # def test_get(self): @@ -108,7 +108,7 @@ # result = self.weaviate_db.get(vector_id=valid_uuid) # assert result.id == valid_uuid - + # expected_payload = mock_response.properties.copy() # expected_payload["id"] = valid_uuid @@ -131,10 +131,10 @@ # "metadata": {"distance": 0.2} # } # ] - + # mock_response = MagicMock() # mock_response.objects = [] - + # for obj in mock_objects: # mock_obj = MagicMock() # mock_obj.uuid = obj["uuid"] @@ -142,16 +142,16 @@ # mock_obj.metadata = MagicMock() # mock_obj.metadata.distance = obj["metadata"]["distance"] # mock_response.objects.append(mock_obj) - + # mock_hybrid = MagicMock() # self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid # mock_hybrid.return_value = mock_response - + # vectors = [[0.1] * 1536] # results = self.weaviate_db.search(query="", vectors=vectors, limit=5) - + # mock_hybrid.assert_called_once() - + # self.assertEqual(len(results), 1) # self.assertEqual(results[0].id, "id1") # self.assertEqual(results[0].score, 0.8) @@ 
-163,28 +163,28 @@ # def test_list(self): # mock_objects = [] - + # mock_obj1 = MagicMock() # mock_obj1.uuid = "id1" # mock_obj1.properties = {"key1": "value1"} # mock_objects.append(mock_obj1) - + # mock_obj2 = MagicMock() # mock_obj2.uuid = "id2" # mock_obj2.properties = {"key2": "value2"} # mock_objects.append(mock_obj2) - + # mock_response = MagicMock() # mock_response.objects = mock_objects - + # mock_fetch = MagicMock() # self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch # mock_fetch.return_value = mock_response - + # results = self.weaviate_db.list(limit=10) - + # mock_fetch.assert_called_once() - + # # Verify results # self.assertEqual(len(results), 1) # self.assertEqual(len(results[0]), 2)