Add: Json Parsing to solve Hallucination Errors (#3013)
This commit is contained in:
@@ -4,6 +4,7 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
client = OpenAI()
|
client = OpenAI()
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ The generated answer might be much longer, but you should be generous with your
|
|||||||
|
|
||||||
For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
|
For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
|
||||||
|
|
||||||
Now it’s time for the real question:
|
Now it's time for the real question:
|
||||||
Question: {question}
|
Question: {question}
|
||||||
Gold answer: {gold_answer}
|
Gold answer: {gold_answer}
|
||||||
Generated answer: {generated_answer}
|
Generated answer: {generated_answer}
|
||||||
@@ -49,7 +50,7 @@ def evaluate_llm_judge(question, gold_answer, generated_answer):
|
|||||||
response_format={"type": "json_object"},
|
response_format={"type": "json_object"},
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
)
|
)
|
||||||
label = json.loads(response.choices[0].message.content)["label"]
|
label = json.loads(extract_json(response.choices[0].message.content))["label"]
|
||||||
return 1 if label == "CORRECT" else 0
|
return 1 if label == "CORRECT" else 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import numpy as np
|
|||||||
|
|
||||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||||
from mem0.embeddings.base import EmbeddingBase
|
from mem0.embeddings.base import EmbeddingBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class AWSBedrockEmbedding(EmbeddingBase):
|
class AWSBedrockEmbedding(EmbeddingBase):
|
||||||
@@ -74,7 +75,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
|
|||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
response_body = json.loads(response.get("body").read())
|
response_body = json.loads(extract_json(response.get("body").read()))
|
||||||
|
|
||||||
if provider == "cohere":
|
if provider == "cohere":
|
||||||
embeddings = response_body.get("embeddings")[0]
|
embeddings = response_body.get("embeddings")[0]
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ class AWSBedrockLLM(LLMBase):
|
|||||||
return processed_response
|
return processed_response
|
||||||
|
|
||||||
response_body = response.get("body").read().decode()
|
response_body = response.get("body").read().decode()
|
||||||
response_json = json.loads(response_body)
|
response_json = json.loads(extract_json(response_body))
|
||||||
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
||||||
|
|
||||||
def _prepare_input(
|
def _prepare_input(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from openai import AzureOpenAI
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAILLM(LLMBase):
|
class AzureOpenAILLM(LLMBase):
|
||||||
@@ -53,7 +54,7 @@ class AzureOpenAILLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from openai import OpenAI
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekLLM(LLMBase):
|
class DeepSeekLLM(LLMBase):
|
||||||
@@ -41,7 +42,7 @@ class DeepSeekLLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class GroqLLM(LLMBase):
|
class GroqLLM(LLMBase):
|
||||||
@@ -43,7 +44,7 @@ class GroqLLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM(LLMBase):
|
class LiteLLM(LLMBase):
|
||||||
@@ -39,7 +40,7 @@ class LiteLLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from openai import OpenAI
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class OpenAILLM(LLMBase):
|
class OpenAILLM(LLMBase):
|
||||||
@@ -62,7 +63,7 @@ class OpenAILLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class TogetherLLM(LLMBase):
|
class TogetherLLM(LLMBase):
|
||||||
@@ -43,7 +44,7 @@ class TogetherLLM(LLMBase):
|
|||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append(
|
||||||
{
|
{
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
|
|
||||||
class VllmLLM(LLMBase):
|
class VllmLLM(LLMBase):
|
||||||
@@ -39,7 +40,7 @@ class VllmLLM(LLMBase):
|
|||||||
for tool_call in response.choices[0].message.tool_calls:
|
for tool_call in response.choices[0].message.tool_calls:
|
||||||
processed_response["tool_calls"].append({
|
processed_response["tool_calls"].append({
|
||||||
"name": tool_call.function.name,
|
"name": tool_call.function.name,
|
||||||
"arguments": json.loads(tool_call.function.arguments),
|
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||||
})
|
})
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
|
|||||||
@@ -46,6 +46,20 @@ def remove_code_blocks(content: str) -> str:
|
|||||||
return match.group(1).strip() if match else content.strip()
|
return match.group(1).strip() if match else content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json(text):
|
||||||
|
"""
|
||||||
|
Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present.
|
||||||
|
If no code block is found, returns the text as-is.
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
json_str = match.group(1)
|
||||||
|
else:
|
||||||
|
json_str = text # assume it's raw JSON
|
||||||
|
return json_str
|
||||||
|
|
||||||
|
|
||||||
def get_image_description(image_obj, llm, vision_details):
|
def get_image_description(image_obj, llm, vision_details):
|
||||||
"""
|
"""
|
||||||
Get the description of the image
|
Get the description of the image
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from mem0.vector_stores.base import VectorStoreBase
|
from mem0.vector_stores.base import VectorStoreBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from azure.core.credentials import AzureKeyCredential
|
from azure.core.credentials import AzureKeyCredential
|
||||||
@@ -233,7 +234,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(extract_json(result["payload"]))
|
||||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -288,7 +289,8 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
result = self.search_client.get_document(key=vector_id)
|
result = self.search_client.get_document(key=vector_id)
|
||||||
except ResourceNotFoundError:
|
except ResourceNotFoundError:
|
||||||
return None
|
return None
|
||||||
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
|
payload = json.loads(extract_json(result["payload"]))
|
||||||
|
return OutputData(id=result["id"], score=None, payload=payload)
|
||||||
|
|
||||||
def list_cols(self) -> List[str]:
|
def list_cols(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
@@ -335,7 +337,7 @@ class AzureAISearch(VectorStoreBase):
|
|||||||
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||||
results = []
|
results = []
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
payload = json.loads(result["payload"])
|
payload = json.loads(extract_json(result["payload"]))
|
||||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||||
return [results]
|
return [results]
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from redisvl.query import VectorQuery
|
|||||||
from redisvl.query.filter import Tag
|
from redisvl.query.filter import Tag
|
||||||
|
|
||||||
from mem0.vector_stores.base import VectorStoreBase
|
from mem0.vector_stores.base import VectorStoreBase
|
||||||
|
from mem0.memory.utils import extract_json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ class RedisDB(VectorStoreBase):
|
|||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for result in results
|
for result in results
|
||||||
@@ -219,7 +220,7 @@ class RedisDB(VectorStoreBase):
|
|||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
return MemoryResult(id=result["memory_id"], payload=payload)
|
return MemoryResult(id=result["memory_id"], payload=payload)
|
||||||
@@ -286,7 +287,7 @@ class RedisDB(VectorStoreBase):
|
|||||||
for field in ["agent_id", "run_id", "user_id"]
|
for field in ["agent_id", "run_id", "user_id"]
|
||||||
if field in result.__dict__
|
if field in result.__dict__
|
||||||
},
|
},
|
||||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for result in results.docs
|
for result in results.docs
|
||||||
|
|||||||
Reference in New Issue
Block a user