Improve multimodal functionality (#2297)
This commit is contained in:
@@ -115,7 +115,10 @@ class Memory(MemoryBase):
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
messages = parse_vision_messages(messages)
|
||||
if self.config.llm.config.get("enable_vision"):
|
||||
messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
|
||||
else:
|
||||
messages = parse_vision_messages(messages)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
|
||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
|
||||
|
||||
def get_fact_retrieval_messages(message):
|
||||
@@ -45,13 +45,13 @@ def remove_code_blocks(content: str) -> str:
|
||||
return match.group(1).strip() if match else content.strip()
|
||||
|
||||
|
||||
def get_image_description(image_url):
|
||||
def get_image_description(image_obj, llm, vision_details):
|
||||
"""
|
||||
Get the description of the image
|
||||
"""
|
||||
llm = OpenAILLM()
|
||||
response = llm.generate_response(
|
||||
messages=[
|
||||
|
||||
if isinstance(image_obj, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -59,31 +59,42 @@ def get_image_description(image_url):
|
||||
"type": "text",
|
||||
"text": "A user is providing an image. Provide a high level description of the image and do not include any additional text.",
|
||||
},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "image_url", "image_url": {"url": image_obj, "detail": vision_details}},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
else:
|
||||
messages = [image_obj]
|
||||
|
||||
response = llm.generate_response(messages=messages)
|
||||
return response
|
||||
|
||||
|
||||
def parse_vision_messages(messages):
|
||||
def parse_vision_messages(messages, llm=None, vision_details="auto"):
|
||||
"""
|
||||
Parse the vision messages from the messages
|
||||
"""
|
||||
returned_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] != "system":
|
||||
if not isinstance(msg["content"], str) and msg["content"]["type"] == "image_url":
|
||||
image_url = msg["content"]["image_url"]["url"]
|
||||
try:
|
||||
description = get_image_description(image_url)
|
||||
msg["content"]["text"] = description
|
||||
returned_messages.append({"role": msg["role"], "content": description})
|
||||
except Exception:
|
||||
raise Exception(f"Error while downloading {image_url}.")
|
||||
else:
|
||||
returned_messages.append(msg)
|
||||
else:
|
||||
if msg["role"] == "system":
|
||||
returned_messages.append(msg)
|
||||
continue
|
||||
|
||||
# Handle message content
|
||||
if isinstance(msg["content"], list):
|
||||
# Multiple image URLs in content
|
||||
description = get_image_description(msg, llm, vision_details)
|
||||
returned_messages.append({"role": msg["role"], "content": description})
|
||||
elif isinstance(msg["content"], dict) and msg["content"].get("type") == "image_url":
|
||||
# Single image content
|
||||
image_url = msg["content"]["image_url"]["url"]
|
||||
try:
|
||||
description = get_image_description(image_url, llm, vision_details)
|
||||
returned_messages.append({"role": msg["role"], "content": description})
|
||||
except Exception:
|
||||
raise Exception(f"Error while downloading {image_url}.")
|
||||
else:
|
||||
# Regular text content
|
||||
returned_messages.append(msg)
|
||||
|
||||
return returned_messages
|
||||
|
||||
Reference in New Issue
Block a user