Add: Json Parsing to solve Hallucination Errors (#3013)

Author: Akshat Jain
Date: 2025-06-23 21:50:16 +05:30
Committed by: GitHub
parent eb24b92227
commit 2bb0653e67
15 changed files with 44 additions and 17 deletions
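
The changes follow a single defensive pattern: anywhere model output or a stored JSON string was passed straight to json.loads(), it is now routed through a new extract_json() helper in mem0.memory.utils, which strips an enclosing Markdown code fence if the model wrapped its JSON in one. A minimal sketch of the failure this guards against, using a hypothetical fenced model response (not taken from the diff):

    import json

    # Hypothetical model output: valid JSON wrapped in a Markdown fence,
    # a common artifact when an LLM is asked to "respond in JSON".
    fenced = '```json\n{"facts": ["likes tea"]}\n```'

    json.loads(fenced)  # raises json.JSONDecodeError: the backticks are not valid JSON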

View File

@@ -11,6 +11,7 @@ import numpy as np
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
+from mem0.memory.utils import extract_json


 class AWSBedrockEmbedding(EmbeddingBase):
@@ -74,7 +75,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
             contentType="application/json",
         )

-        response_body = json.loads(response.get("body").read())
+        response_body = json.loads(extract_json(response.get("body").read()))

         if provider == "cohere":
             embeddings = response_body.get("embeddings")[0]

View File

@@ -10,6 +10,7 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json

 PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
@@ -101,7 +102,7 @@ class AWSBedrockLLM(LLMBase):
             return processed_response

         response_body = response.get("body").read().decode()
-        response_json = json.loads(response_body)
+        response_json = json.loads(extract_json(response_body))
         return response_json.get("content", [{"text": ""}])[0].get("text", "")

     def _prepare_input(

View File

@@ -6,6 +6,7 @@ from openai import AzureOpenAI
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class AzureOpenAILLM(LLMBase):
@@ -53,7 +54,7 @@ class AzureOpenAILLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -6,6 +6,7 @@ from openai import OpenAI
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class DeepSeekLLM(LLMBase):
@@ -41,7 +42,7 @@ class DeepSeekLLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -9,6 +9,7 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class GroqLLM(LLMBase):
@@ -43,7 +44,7 @@ class GroqLLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -8,6 +8,7 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class LiteLLM(LLMBase):
@@ -39,7 +40,7 @@ class LiteLLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -7,6 +7,7 @@ from openai import OpenAI
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class OpenAILLM(LLMBase):
@@ -62,7 +63,7 @@ class OpenAILLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -9,6 +9,7 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class TogetherLLM(LLMBase):
@@ -43,7 +44,7 @@ class TogetherLLM(LLMBase):
                     processed_response["tool_calls"].append(
                         {
                             "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
+                            "arguments": json.loads(extract_json(tool_call.function.arguments)),
                         }
                     )

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json


 class VllmLLM(LLMBase):
@@ -39,7 +40,7 @@ class VllmLLM(LLMBase):
             for tool_call in response.choices[0].message.tool_calls:
                 processed_response["tool_calls"].append({
                     "name": tool_call.function.name,
-                    "arguments": json.loads(tool_call.function.arguments),
+                    "arguments": json.loads(extract_json(tool_call.function.arguments)),
                 })

         return processed_response

View File

@@ -46,6 +46,20 @@ def remove_code_blocks(content: str) -> str:
     return match.group(1).strip() if match else content.strip()


+def extract_json(text):
+    """
+    Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present.
+    If no code block is found, returns the text as-is.
+    """
+    text = text.strip()
+    match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
+    if match:
+        json_str = match.group(1)
+    else:
+        json_str = text  # assume it's raw JSON
+    return json_str
+
+
 def get_image_description(image_obj, llm, vision_details):
     """
     Get the description of the image
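
Given the regex above, extract_json should behave as follows on fenced versus raw input (illustrative strings, assuming the helper is imported as in the other hunks):

    from mem0.memory.utils import extract_json

    fenced = '```json\n{"query": "coffee preferences"}\n```'
    raw = '{"query": "coffee preferences"}'

    assert extract_json(fenced) == raw  # fence and optional "json" tag stripped
    assert extract_json(raw) == raw     # raw JSON is returned unchanged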

View File

@@ -6,6 +6,7 @@ from typing import List, Optional
 from pydantic import BaseModel

 from mem0.vector_stores.base import VectorStoreBase
+from mem0.memory.utils import extract_json

 try:
     from azure.core.credentials import AzureKeyCredential
@@ -233,7 +234,7 @@ class AzureAISearch(VectorStoreBase):
         results = []
         for result in search_results:
-            payload = json.loads(result["payload"])
+            payload = json.loads(extract_json(result["payload"]))
             results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
         return results
@@ -288,7 +289,8 @@ class AzureAISearch(VectorStoreBase):
             result = self.search_client.get_document(key=vector_id)
         except ResourceNotFoundError:
             return None
-        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
+        payload = json.loads(extract_json(result["payload"]))
+        return OutputData(id=result["id"], score=None, payload=payload)

     def list_cols(self) -> List[str]:
         """
@@ -335,7 +337,7 @@ class AzureAISearch(VectorStoreBase):
         search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
         results = []
         for result in search_results:
-            payload = json.loads(result["payload"])
+            payload = json.loads(extract_json(result["payload"]))
             results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
         return [results]

View File

@@ -12,6 +12,7 @@ from redisvl.query import VectorQuery
 from redisvl.query.filter import Tag

 from mem0.vector_stores.base import VectorStoreBase
+from mem0.memory.utils import extract_json

 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class RedisDB(VectorStoreBase):
                         else {}
                     ),
                     **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
-                    **{k: v for k, v in json.loads(result["metadata"]).items()},
+                    **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
                 },
             )
             for result in results
@@ -219,7 +220,7 @@ class RedisDB(VectorStoreBase):
                 else {}
             ),
             **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
-            **{k: v for k, v in json.loads(result["metadata"]).items()},
+            **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
         }

         return MemoryResult(id=result["memory_id"], payload=payload)
@@ -286,7 +287,7 @@ class RedisDB(VectorStoreBase):
                         for field in ["agent_id", "run_id", "user_id"]
                         if field in result.__dict__
                     },
-                    **{k: v for k, v in json.loads(result["metadata"]).items()},
+                    **{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
                 },
             )
             for result in results.docs