From 2bb0653e679c188d1a8468d95a5d3ec01a4bc309 Mon Sep 17 00:00:00 2001 From: Akshat Jain <125379408+akshat1423@users.noreply.github.com> Date: Mon, 23 Jun 2025 21:50:16 +0530 Subject: [PATCH] Add: Json Parsing to solve Hallucination Errors (#3013) --- evaluation/metrics/llm_judge.py | 5 +++-- mem0/embeddings/aws_bedrock.py | 3 ++- mem0/llms/aws_bedrock.py | 3 ++- mem0/llms/azure_openai.py | 3 ++- mem0/llms/deepseek.py | 3 ++- mem0/llms/groq.py | 3 ++- mem0/llms/litellm.py | 3 ++- mem0/llms/openai.py | 3 ++- mem0/llms/together.py | 3 ++- mem0/llms/utils/__init__.py | 0 mem0/llms/utils/functions.py | 0 mem0/llms/vllm.py | 3 ++- mem0/memory/utils.py | 14 ++++++++++++++ mem0/vector_stores/azure_ai_search.py | 8 +++++--- mem0/vector_stores/redis.py | 7 ++++--- 15 files changed, 44 insertions(+), 17 deletions(-) delete mode 100644 mem0/llms/utils/__init__.py delete mode 100644 mem0/llms/utils/functions.py diff --git a/evaluation/metrics/llm_judge.py b/evaluation/metrics/llm_judge.py index 4d0ec376..8dbae80b 100644 --- a/evaluation/metrics/llm_judge.py +++ b/evaluation/metrics/llm_judge.py @@ -4,6 +4,7 @@ from collections import defaultdict import numpy as np from openai import OpenAI +from mem0.memory.utils import extract_json client = OpenAI() @@ -22,7 +23,7 @@ The generated answer might be much longer, but you should be generous with your For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date. -Now it’s time for the real question: +Now it's time for the real question: Question: {question} Gold answer: {gold_answer} Generated answer: {generated_answer} @@ -49,7 +50,7 @@ def evaluate_llm_judge(question, gold_answer, generated_answer): response_format={"type": "json_object"}, temperature=0.0, ) - label = json.loads(response.choices[0].message.content)["label"] + label = json.loads(extract_json(response.choices[0].message.content))["label"] return 1 if label == "CORRECT" else 0 diff --git a/mem0/embeddings/aws_bedrock.py b/mem0/embeddings/aws_bedrock.py index 807764c1..1679aeec 100644 --- a/mem0/embeddings/aws_bedrock.py +++ b/mem0/embeddings/aws_bedrock.py @@ -11,6 +11,7 @@ import numpy as np from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.embeddings.base import EmbeddingBase +from mem0.memory.utils import extract_json class AWSBedrockEmbedding(EmbeddingBase): @@ -74,7 +75,7 @@ class AWSBedrockEmbedding(EmbeddingBase): contentType="application/json", ) - response_body = json.loads(response.get("body").read()) + response_body = json.loads(extract_json(response.get("body").read())) if provider == "cohere": embeddings = response_body.get("embeddings")[0] diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index dde66f30..91a2806f 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -10,6 +10,7 @@ except ImportError: from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"] @@ -101,7 +102,7 @@ class AWSBedrockLLM(LLMBase): return processed_response response_body = response.get("body").read().decode() - response_json = json.loads(response_body) + response_json = json.loads(extract_json(response_body)) return response_json.get("content", [{"text": ""}])[0].get("text", "") def _prepare_input( diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index a7f1fdaf..1dcb5f3f 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -6,6 +6,7 @@ from openai import AzureOpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class AzureOpenAILLM(LLMBase): @@ -53,7 +54,7 @@ class AzureOpenAILLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/deepseek.py b/mem0/llms/deepseek.py index 46a805f0..85a0417a 100644 --- a/mem0/llms/deepseek.py +++ b/mem0/llms/deepseek.py @@ -6,6 +6,7 @@ from openai import OpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class DeepSeekLLM(LLMBase): @@ -41,7 +42,7 @@ class DeepSeekLLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py index e970a1ee..cc8733d5 100644 --- a/mem0/llms/groq.py +++ b/mem0/llms/groq.py @@ -9,6 +9,7 @@ except ImportError: from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class GroqLLM(LLMBase): @@ -43,7 +44,7 @@ class GroqLLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py index d5896ff8..3a5ef60c 100644 --- a/mem0/llms/litellm.py +++ b/mem0/llms/litellm.py @@ -8,6 +8,7 @@ except ImportError: from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class LiteLLM(LLMBase): @@ -39,7 +40,7 @@ class LiteLLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index 7fc3ff4d..8f3cf807 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -7,6 +7,7 @@ from openai import OpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class OpenAILLM(LLMBase): @@ -62,7 +63,7 @@ class OpenAILLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/together.py b/mem0/llms/together.py index 922a30d2..d2af10c1 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -9,6 +9,7 @@ except ImportError: from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class TogetherLLM(LLMBase): @@ -43,7 +44,7 @@ class TogetherLLM(LLMBase): processed_response["tool_calls"].append( { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), } ) diff --git a/mem0/llms/utils/__init__.py b/mem0/llms/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mem0/llms/utils/functions.py b/mem0/llms/utils/functions.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mem0/llms/vllm.py b/mem0/llms/vllm.py index e522068b..1f0fc822 100644 --- a/mem0/llms/vllm.py +++ b/mem0/llms/vllm.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase +from mem0.memory.utils import extract_json class VllmLLM(LLMBase): @@ -39,7 +40,7 @@ class VllmLLM(LLMBase): for tool_call in response.choices[0].message.tool_calls: processed_response["tool_calls"].append({ "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), + "arguments": json.loads(extract_json(tool_call.function.arguments)), }) return processed_response diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py index b083508d..9018d546 100644 --- a/mem0/memory/utils.py +++ b/mem0/memory/utils.py @@ -46,6 +46,20 @@ def remove_code_blocks(content: str) -> str: return match.group(1).strip() if match else content.strip() +def extract_json(text): + """ + Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present. + If no code block is found, returns the text as-is. + """ + text = text.strip() + match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL) + if match: + json_str = match.group(1) + else: + json_str = text # assume it's raw JSON + return json_str + + def get_image_description(image_obj, llm, vision_details): """ Get the description of the image diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index 6acd2e40..7e06cff7 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -6,6 +6,7 @@ from typing import List, Optional from pydantic import BaseModel from mem0.vector_stores.base import VectorStoreBase +from mem0.memory.utils import extract_json try: from azure.core.credentials import AzureKeyCredential @@ -233,7 +234,7 @@ class AzureAISearch(VectorStoreBase): results = [] for result in search_results: - payload = json.loads(result["payload"]) + payload = json.loads(extract_json(result["payload"])) results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) return results @@ -288,7 +289,8 @@ class AzureAISearch(VectorStoreBase): result = self.search_client.get_document(key=vector_id) except ResourceNotFoundError: return None - return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) + payload = json.loads(extract_json(result["payload"])) + return OutputData(id=result["id"], score=None, payload=payload) def list_cols(self) -> List[str]: """ @@ -335,7 +337,7 @@ class AzureAISearch(VectorStoreBase): search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit) results = [] for result in search_results: - payload = json.loads(result["payload"]) + payload = json.loads(extract_json(result["payload"])) results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) return [results] diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py index d2819975..8e85055f 100644 --- a/mem0/vector_stores/redis.py +++ b/mem0/vector_stores/redis.py @@ -12,6 +12,7 @@ from redisvl.query import VectorQuery from redisvl.query.filter import Tag from mem0.vector_stores.base import VectorStoreBase +from mem0.memory.utils import extract_json logger = logging.getLogger(__name__) @@ -175,7 +176,7 @@ class RedisDB(VectorStoreBase): else {} ), **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result}, - **{k: v for k, v in json.loads(result["metadata"]).items()}, + **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()}, }, ) for result in results @@ -219,7 +220,7 @@ class RedisDB(VectorStoreBase): else {} ), **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result}, - **{k: v for k, v in json.loads(result["metadata"]).items()}, + **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()}, } return MemoryResult(id=result["memory_id"], payload=payload) @@ -286,7 +287,7 @@ class RedisDB(VectorStoreBase): for field in ["agent_id", "run_id", "user_id"] if field in result.__dict__ }, - **{k: v for k, v in json.loads(result["metadata"]).items()}, + **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()}, }, ) for result in results.docs