Add: Json Parsing to solve Hallucination Errors (#3013)
This commit is contained in:
@@ -11,6 +11,7 @@ import numpy as np
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class AWSBedrockEmbedding(EmbeddingBase):
|
||||
@@ -74,7 +75,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
response_body = json.loads(extract_json(response.get("body").read()))
|
||||
|
||||
if provider == "cohere":
|
||||
embeddings = response_body.get("embeddings")[0]
|
||||
|
||||
@@ -10,6 +10,7 @@ except ImportError:
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
||||
|
||||
@@ -101,7 +102,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
return processed_response
|
||||
|
||||
response_body = response.get("body").read().decode()
|
||||
response_json = json.loads(response_body)
|
||||
response_json = json.loads(extract_json(response_body))
|
||||
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
||||
|
||||
def _prepare_input(
|
||||
|
||||
@@ -6,6 +6,7 @@ from openai import AzureOpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class AzureOpenAILLM(LLMBase):
|
||||
@@ -53,7 +54,7 @@ class AzureOpenAILLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class DeepSeekLLM(LLMBase):
|
||||
@@ -41,7 +42,7 @@ class DeepSeekLLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ except ImportError:
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class GroqLLM(LLMBase):
|
||||
@@ -43,7 +44,7 @@ class GroqLLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ except ImportError:
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class LiteLLM(LLMBase):
|
||||
@@ -39,7 +40,7 @@ class LiteLLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class OpenAILLM(LLMBase):
|
||||
@@ -62,7 +63,7 @@ class OpenAILLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ except ImportError:
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class TogetherLLM(LLMBase):
|
||||
@@ -43,7 +44,7 @@ class TogetherLLM(LLMBase):
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
|
||||
class VllmLLM(LLMBase):
|
||||
@@ -39,7 +40,7 @@ class VllmLLM(LLMBase):
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
processed_response["tool_calls"].append({
|
||||
"name": tool_call.function.name,
|
||||
"arguments": json.loads(tool_call.function.arguments),
|
||||
"arguments": json.loads(extract_json(tool_call.function.arguments)),
|
||||
})
|
||||
|
||||
return processed_response
|
||||
|
||||
@@ -46,6 +46,20 @@ def remove_code_blocks(content: str) -> str:
|
||||
return match.group(1).strip() if match else content.strip()
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
"""
|
||||
Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present.
|
||||
If no code block is found, returns the text as-is.
|
||||
"""
|
||||
text = text.strip()
|
||||
match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
|
||||
if match:
|
||||
json_str = match.group(1)
|
||||
else:
|
||||
json_str = text # assume it's raw JSON
|
||||
return json_str
|
||||
|
||||
|
||||
def get_image_description(image_obj, llm, vision_details):
|
||||
"""
|
||||
Get the description of the image
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
@@ -233,7 +234,7 @@ class AzureAISearch(VectorStoreBase):
|
||||
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return results
|
||||
|
||||
@@ -288,7 +289,8 @@ class AzureAISearch(VectorStoreBase):
|
||||
result = self.search_client.get_document(key=vector_id)
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
return OutputData(id=result["id"], score=None, payload=payload)
|
||||
|
||||
def list_cols(self) -> List[str]:
|
||||
"""
|
||||
@@ -335,7 +337,7 @@ class AzureAISearch(VectorStoreBase):
|
||||
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
|
||||
results = []
|
||||
for result in search_results:
|
||||
payload = json.loads(result["payload"])
|
||||
payload = json.loads(extract_json(result["payload"]))
|
||||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
|
||||
return [results]
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from redisvl.query import VectorQuery
|
||||
from redisvl.query.filter import Tag
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -175,7 +176,7 @@ class RedisDB(VectorStoreBase):
|
||||
else {}
|
||||
),
|
||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
},
|
||||
)
|
||||
for result in results
|
||||
@@ -219,7 +220,7 @@ class RedisDB(VectorStoreBase):
|
||||
else {}
|
||||
),
|
||||
**{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
|
||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
}
|
||||
|
||||
return MemoryResult(id=result["memory_id"], payload=payload)
|
||||
@@ -286,7 +287,7 @@ class RedisDB(VectorStoreBase):
|
||||
for field in ["agent_id", "run_id", "user_id"]
|
||||
if field in result.__dict__
|
||||
},
|
||||
**{k: v for k, v in json.loads(result["metadata"]).items()},
|
||||
**{k: v for k, v in json.loads(extract_json(result["metadata"])).items()},
|
||||
},
|
||||
)
|
||||
for result in results.docs
|
||||
|
||||
Reference in New Issue
Block a user