Added support of vision input

This commit is contained in:
Prateek Chhikara
2025-02-18 11:47:13 -08:00
committed by GitHub
parent cbee71a63e
commit cc9acb7493
9 changed files with 973 additions and 427 deletions

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
import pytz
from pydantic import ValidationError
from mem0.memory.utils import parse_vision_messages
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
@@ -114,6 +114,8 @@ class Memory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
messages = parse_vision_messages(messages)
with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
@@ -143,7 +145,7 @@ class Memory(MemoryBase):
if self.custom_prompt:
system_prompt = self.custom_prompt
user_prompt = f"Input: {parsed_messages}"
user_prompt = f"Input:\n{parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)

View File

@@ -1,10 +1,10 @@
import re
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
from mem0.llms.openai import OpenAILLM
def get_fact_retrieval_messages(message):
return FACT_RETRIEVAL_PROMPT, f"Input: {message}"
return FACT_RETRIEVAL_PROMPT, f"Input:\n{message}"
def parse_messages(messages):
@@ -43,3 +43,45 @@ def remove_code_blocks(content: str) -> str:
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
match = re.match(pattern, content.strip())
return match.group(1).strip() if match else content.strip()
def get_image_description(image_url):
"""
Get the description of the image
"""
llm = OpenAILLM()
response = llm.generate_response(
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Provide a description of the image and do not include any additional text."},
{"type": "image_url", "image_url": {"url": image_url}}
],
},
],
max_tokens=100,
)
return response
def parse_vision_messages(messages):
"""
Parse the vision messages from the messages
"""
returned_messages = []
for msg in messages:
if msg["role"] != "system":
if not isinstance(msg["content"], str) and msg["content"]["type"] == "image_url":
image_url = msg["content"]["image_url"]["url"]
try:
description = get_image_description(image_url)
msg["content"]["text"] = description
returned_messages.append({"role": msg["role"], "content": description})
except Exception:
raise Exception(f"Error while downloading {image_url}.")
else:
returned_messages.append(msg)
else:
returned_messages.append(msg)
return returned_messages