Update Categorisation Flow (#2922)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
This commit is contained in:
Akshat Jain
2025-06-11 21:24:15 +05:30
committed by GitHub
parent aa334fb569
commit c59752c6d6

View File

@@ -1,15 +1,13 @@
import json
import logging
from openai import OpenAI
from typing import List
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
load_dotenv()
openai_client = OpenAI()
@@ -19,19 +17,28 @@ class MemoryCategories(BaseModel):
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def get_categories_for_memory(memory: str) -> List[str]:
"""Get categories for a memory."""
try:
response = openai_client.responses.parse(
messages = [
{"role": "system", "content": MEMORY_CATEGORIZATION_PROMPT},
{"role": "user", "content": memory}
]
# Let OpenAI handle the pydantic parsing directly
completion = openai_client.chat.completions.with_response_format(
response_format=MemoryCategories
).create(
model="gpt-4o-mini",
instructions=MEMORY_CATEGORIZATION_PROMPT,
input=memory,
temperature=0,
text_format=MemoryCategories,
messages=messages,
temperature=0
)
response_json =json.loads(response.output[0].content[0].text)
categories = response_json['categories']
categories = [cat.strip().lower() for cat in categories]
# TODO: Validate categories later may be
return categories
parsed: MemoryCategories = completion.choices[0].message.parsed
return [cat.strip().lower() for cat in parsed.categories]
except Exception as e:
raise e
logging.error(f"[ERROR] Failed to get categories: {e}")
try:
logging.debug(f"[DEBUG] Raw response: {completion.choices[0].message.content}")
except Exception as debug_e:
logging.debug(f"[DEBUG] Could not extract raw response: {debug_e}")
raise