Add: Json Parsing to solve Hallucination Errors (#3013)

This commit is contained in:
Akshat Jain
2025-06-23 21:50:16 +05:30
committed by GitHub
parent eb24b92227
commit 2bb0653e67
15 changed files with 44 additions and 17 deletions

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
import numpy as np import numpy as np
from openai import OpenAI from openai import OpenAI
from mem0.memory.utils import extract_json
client = OpenAI() client = OpenAI()
@@ -22,7 +23,7 @@ The generated answer might be much longer, but you should be generous with your
For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date. For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
Now its time for the real question: Now it's time for the real question:
Question: {question} Question: {question}
Gold answer: {gold_answer} Gold answer: {gold_answer}
Generated answer: {generated_answer} Generated answer: {generated_answer}
@@ -49,7 +50,7 @@ def evaluate_llm_judge(question, gold_answer, generated_answer):
response_format={"type": "json_object"}, response_format={"type": "json_object"},
temperature=0.0, temperature=0.0,
) )
label = json.loads(response.choices[0].message.content)["label"] label = json.loads(extract_json(response.choices[0].message.content))["label"]
return 1 if label == "CORRECT" else 0 return 1 if label == "CORRECT" else 0

View File

@@ -11,6 +11,7 @@ import numpy as np
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase from mem0.embeddings.base import EmbeddingBase
from mem0.memory.utils import extract_json
class AWSBedrockEmbedding(EmbeddingBase): class AWSBedrockEmbedding(EmbeddingBase):
@@ -74,7 +75,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
contentType="application/json", contentType="application/json",
) )
response_body = json.loads(response.get("body").read()) response_body = json.loads(extract_json(response.get("body").read()))
if provider == "cohere": if provider == "cohere":
embeddings = response_body.get("embeddings")[0] embeddings = response_body.get("embeddings")[0]

View File

@@ -10,6 +10,7 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"] PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
@@ -101,7 +102,7 @@ class AWSBedrockLLM(LLMBase):
return processed_response return processed_response
response_body = response.get("body").read().decode() response_body = response.get("body").read().decode()
response_json = json.loads(response_body) response_json = json.loads(extract_json(response_body))
return response_json.get("content", [{"text": ""}])[0].get("text", "") return response_json.get("content", [{"text": ""}])[0].get("text", "")
def _prepare_input( def _prepare_input(

View File

@@ -6,6 +6,7 @@ from openai import AzureOpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class AzureOpenAILLM(LLMBase): class AzureOpenAILLM(LLMBase):
@@ -53,7 +54,7 @@ class AzureOpenAILLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -6,6 +6,7 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class DeepSeekLLM(LLMBase): class DeepSeekLLM(LLMBase):
@@ -41,7 +42,7 @@ class DeepSeekLLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -9,6 +9,7 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class GroqLLM(LLMBase): class GroqLLM(LLMBase):
@@ -43,7 +44,7 @@ class GroqLLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -8,6 +8,7 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class LiteLLM(LLMBase): class LiteLLM(LLMBase):
@@ -39,7 +40,7 @@ class LiteLLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -7,6 +7,7 @@ from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class OpenAILLM(LLMBase): class OpenAILLM(LLMBase):
@@ -62,7 +63,7 @@ class OpenAILLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -9,6 +9,7 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class TogetherLLM(LLMBase): class TogetherLLM(LLMBase):
@@ -43,7 +44,7 @@ class TogetherLLM(LLMBase):
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
} }
) )

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
class VllmLLM(LLMBase): class VllmLLM(LLMBase):
@@ -39,7 +40,7 @@ class VllmLLM(LLMBase):
for tool_call in response.choices[0].message.tool_calls: for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append({ processed_response["tool_calls"].append({
"name": tool_call.function.name, "name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments), "arguments": json.loads(extract_json(tool_call.function.arguments)),
}) })
return processed_response return processed_response

View File

@@ -46,6 +46,20 @@ def remove_code_blocks(content: str) -> str:
return match.group(1).strip() if match else content.strip() return match.group(1).strip() if match else content.strip()
def extract_json(text: "str | bytes") -> str:
    """Extract a JSON payload from LLM output.

    LLMs frequently wrap JSON in a Markdown code fence (``` or ```json).
    This strips the fence when present; otherwise the input is assumed to
    already be raw JSON and is returned (stripped of surrounding whitespace).

    Args:
        text: The model output. May be ``str`` or ``bytes`` — some callers
            (e.g. AWS Bedrock's ``response["body"].read()``) produce bytes,
            which would otherwise make ``re.search`` raise ``TypeError``.

    Returns:
        The JSON string with any enclosing code fence removed. The result is
        NOT validated; pass it to ``json.loads`` for parsing.
    """
    # Normalize bytes to str so the str regex below works for all callers.
    # NOTE(review): assumes UTF-8 payloads — the usual encoding for JSON APIs.
    if isinstance(text, bytes):
        text = text.decode("utf-8")

    text = text.strip()
    # Non-greedy match so only the fenced content is captured; DOTALL lets
    # the JSON span multiple lines.
    match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if match:
        return match.group(1)
    return text  # no fence found — assume it's already raw JSON
def get_image_description(image_obj, llm, vision_details): def get_image_description(image_obj, llm, vision_details):
""" """
Get the description of the image Get the description of the image

View File

@@ -6,6 +6,7 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json
try: try:
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential
@@ -233,7 +234,7 @@ class AzureAISearch(VectorStoreBase):
results = [] results = []
for result in search_results: for result in search_results:
payload = json.loads(result["payload"]) payload = json.loads(extract_json(result["payload"]))
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results return results
@@ -288,7 +289,8 @@ class AzureAISearch(VectorStoreBase):
result = self.search_client.get_document(key=vector_id) result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError: except ResourceNotFoundError:
return None return None
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) payload = json.loads(extract_json(result["payload"]))
return OutputData(id=result["id"], score=None, payload=payload)
def list_cols(self) -> List[str]: def list_cols(self) -> List[str]:
""" """
@@ -335,7 +337,7 @@ class AzureAISearch(VectorStoreBase):
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit) search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
results = [] results = []
for result in search_results: for result in search_results:
payload = json.loads(result["payload"]) payload = json.loads(extract_json(result["payload"]))
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results] return [results]

View File

@@ -12,6 +12,7 @@ from redisvl.query import VectorQuery
from redisvl.query.filter import Tag from redisvl.query.filter import Tag
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class RedisDB(VectorStoreBase):
else {} else {}
), ),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result}, **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(result["metadata"]).items()}, **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
}, },
) )
for result in results for result in results
@@ -219,7 +220,7 @@ class RedisDB(VectorStoreBase):
else {} else {}
), ),
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result}, **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
**{k: v for k, v in json.loads(result["metadata"]).items()}, **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
} }
return MemoryResult(id=result["memory_id"], payload=payload) return MemoryResult(id=result["memory_id"], payload=payload)
@@ -286,7 +287,7 @@ class RedisDB(VectorStoreBase):
for field in ["agent_id", "run_id", "user_id"] for field in ["agent_id", "run_id", "user_id"]
if field in result.__dict__ if field in result.__dict__
}, },
**{k: v for k, v in json.loads(result["metadata"]).items()}, **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
}, },
) )
for result in results.docs for result in results.docs