Update Categorisation Flow (#2922)
Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
This commit is contained in:
@@ -1,15 +1,13 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from openai import OpenAI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
|
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
openai_client = OpenAI()
|
openai_client = OpenAI()
|
||||||
|
|
||||||
|
|
||||||
@@ -19,19 +17,28 @@ class MemoryCategories(BaseModel):
|
|||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
def get_categories_for_memory(memory: str) -> List[str]:
    """Return the category labels the model assigns to *memory*.

    The memory text is sent to ``gpt-4o-mini`` together with
    ``MEMORY_CATEGORIZATION_PROMPT`` and the reply is parsed by the OpenAI
    SDK into a ``MemoryCategories`` instance. Each returned category is
    stripped of surrounding whitespace and lower-cased.

    The ``@retry`` decorator re-runs the call (up to 3 attempts, with an
    exponential wait bounded between 4 and 15 seconds) whenever an exception
    escapes this function.

    Args:
        memory: The memory text to categorise.

    Returns:
        The normalised category names.

    Raises:
        ValueError: If the model reply carries no parsed payload
            (for example, a refusal).
        Exception: Any error raised by the OpenAI client is logged and
            re-raised so the retry policy can act on it.
    """
    messages = [
        {"role": "system", "content": MEMORY_CATEGORIZATION_PROMPT},
        {"role": "user", "content": memory},
    ]

    # Bound before the try so the error handler can tell "the request itself
    # failed" apart from "the request succeeded but the reply was unusable".
    completion = None
    try:
        # Let the OpenAI SDK validate the reply against the pydantic model.
        completion = openai_client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=messages,
            response_format=MemoryCategories,
            temperature=0,
        )

        parsed = completion.choices[0].message.parsed
        if parsed is None:
            # `parsed` is None when the model refuses or returns no
            # structured content; fail with an explicit message.
            raise ValueError("Model response did not contain parsed categories")

        return [cat.strip().lower() for cat in parsed.categories]
    except Exception as e:
        logging.error("[ERROR] Failed to get categories: %s", e)
        if completion is not None:
            # Best-effort diagnostics only; never mask the original error.
            try:
                logging.debug(
                    "[DEBUG] Raw response: %s",
                    completion.choices[0].message.content,
                )
            except Exception as debug_e:
                logging.debug(
                    "[DEBUG] Could not extract raw response: %s", debug_e
                )
        raise
|
||||||
|
|||||||
Reference in New Issue
Block a user