"""Claude Sonnet 4 vision pipeline: legend detection + symbol counting.

Two LLM passes over a drawing image, both sent through an OpenAI-compatible
LiteLLM proxy:

1. ``detect_legend`` asks the model to locate the legend and return each
   legend symbol with a normalized bounding box.
2. ``count_symbol`` shows the model one symbol crop plus the full drawing
   and asks how many times that symbol occurs.
"""

import asyncio
import base64
import json
import logging
import mimetypes
import os
from pathlib import Path

from openai import AsyncOpenAI

logger = logging.getLogger(__name__)

# Lazily created by _get_client() so importing this module needs no env/network.
_client: AsyncOpenAI | None = None


def _get_client() -> AsyncOpenAI:
    """Return the shared async client, creating it on first use.

    The base URL and API key are read from ``LITELLM_BASE_URL`` and
    ``LITELLM_API_KEY`` (with in-cluster proxy defaults).
    """
    global _client
    if _client is None:
        _client = AsyncOpenAI(
            base_url=os.getenv("LITELLM_BASE_URL", "http://litellm-proxy:4000/v1"),
            api_key=os.getenv("LITELLM_API_KEY", "sk-dummy"),
        )
    return _client


def _b64(path: Path) -> str:
    """Return the raw bytes of *path* encoded as an ASCII base64 string."""
    return base64.b64encode(path.read_bytes()).decode("ascii")


def _data_url(path: Path) -> str:
    """Return a ``data:`` URL for the image at *path*.

    The media type is guessed from the file extension so that e.g. JPEG
    files are not mislabelled; unknown or non-image extensions fall back to
    ``image/png`` (the previous hard-coded value).
    """
    mime, _ = mimetypes.guess_type(path.name)
    if not mime or not mime.startswith("image/"):
        mime = "image/png"
    return f"data:{mime};base64,{_b64(path)}"


def _reply_text(resp, what: str) -> str:
    """Extract the stripped text of the first choice of a chat completion.

    Args:
        resp: Chat-completion response returned by the client.
        what: Short label used in log messages (e.g. ``"Legend"``).

    Returns:
        The message content, or ``""`` if the response has no choices or
        no content.
    """
    if not resp.choices:
        logger.warning("%s response contained no choices", what)
        return ""
    choice = resp.choices[0]
    if choice.finish_reason == "length":
        # The reply hit max_tokens, so any JSON in it is probably cut off.
        logger.warning("%s response hit max_tokens and may be truncated", what)
    return (choice.message.content or "").strip()


def _parse_json_reply(raw: str) -> object:
    """Decode the JSON value in a model reply.

    Tolerates markdown code fences and surrounding prose: if the fence-stripped
    text is not valid JSON on its own, the first JSON object starting at the
    first ``{`` is decoded instead (trailing text is ignored).

    Raises:
        json.JSONDecodeError: if no JSON value can be decoded.
    """
    text = raw.strip()
    text = text.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        if start == -1:
            raise
        value, _ = json.JSONDecoder().raw_decode(text, start)
        return value


MODEL = os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-20250514")

# Sent verbatim (no str.format), so braces are single.
LEGEND_PROMPT = """You are analyzing a Czech architectural / engineering drawing (HVAC, electrical, plumbing, fire safety, etc).

NOTE: The drawing may be rotated 90° or 180°. Mentally rotate the page until text reads horizontally before searching. The legend is typically near a page edge or corner, often in a colored (yellow / red / green) text block that lists symbols and their meanings.

Your task: find any legend / symbol key in this drawing. Possible Czech headings include "LEGENDA", "VYSVĚTLIVKY", "POPIS", "POPIS SYMBOLŮ", "POUŽITÉ ZNAČENÍ". A legend is a column-table where each row pairs a small graphical symbol with a Czech description.

For each symbol entry found, return:
- `id`: short stable identifier (1-3 words, lowercase with underscores; e.g. "smoke_detector", "socket_230v", "valve_3way")
- `description`: full Czech description text exactly as written
- `bbox`: bounding box of THE SYMBOL ONLY (NOT the description text), normalized 0-1 coords relative to the full image. Format `{"x": 0.05, "y": 0.10, "w": 0.02, "h": 0.02}` where x,y is the top-left corner.

Return ONLY valid JSON, no markdown, no commentary:
{"symbols": [{"id":"...","description":"...","bbox":{"x":0,"y":0,"w":0,"h":0}}]}

If you genuinely cannot find any legend, return {"symbols": [], "reason": "<short explanation>"}.
Skip rows that are room schedules, material totals, or project info — only include rows showing actual graphical symbols with descriptions."""

# Filled with str.format(description=...); literal braces are doubled.
COUNT_PROMPT_TEMPLATE = """You are counting graphical symbols in a Czech architectural/engineering drawing.

I will show you:
1. A reference symbol crop from the drawing's legend
2. The full drawing image

Your task: count the number of times the reference symbol appears in the full drawing. Look only at the drawing area (not the legend itself). The symbol may be rotated 0°, 90°, 180°, 270° — count rotated instances. Ignore size variations within reason (the symbol scale should be similar).

Reference symbol description (from legend): "{description}"

Return ONLY valid JSON, no markdown:
{{"count": <integer>, "confidence": <"low"|"medium"|"high">, "notes": "<short note>"}}"""


async def detect_legend(image_path: Path) -> list[dict]:
    """Pass 1: Find legend, return list of symbols with bbox.

    Args:
        image_path: Image file of the full drawing.

    Returns:
        The ``symbols`` entries reported by the model (only dict entries are
        kept). Returns ``[]`` when no legend is found or the reply cannot be
        parsed.
    """
    # Reading + base64-encoding a full drawing is blocking work; keep it off
    # the event loop.
    image_url = await asyncio.to_thread(_data_url, image_path)
    resp = await _get_client().chat.completions.create(
        model=MODEL,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": LEGEND_PROMPT},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }],
        max_tokens=4000,
        temperature=0.0,
    )
    raw = _reply_text(resp, "Legend")
    logger.info("Legend raw response (first 800 chars): %s", raw[:800])
    try:
        data = _parse_json_reply(raw)
    except json.JSONDecodeError as e:
        logger.error("Legend JSON parse failed: %s\nraw=%s", e, raw[:500])
        return []

    symbols = data.get("symbols") if isinstance(data, dict) else None
    if not isinstance(symbols, list):
        # e.g. "symbols": null or a non-object top-level reply.
        symbols = []
    # Non-object rows cannot carry id/description/bbox, so drop them.
    symbols = [entry for entry in symbols if isinstance(entry, dict)]
    if not symbols and isinstance(data, dict) and data.get("reason"):
        logger.info("No legend symbols reported: %s", data["reason"])
    logger.info("Legend detection found %d symbols", len(symbols))
    return symbols


def _count_parse_failed() -> dict:
    """Return a fresh fallback result for an unusable count reply."""
    return {"count": 0, "confidence": "low", "notes": "parse failed"}


def _parse_count_reply(raw: str) -> dict:
    """Parse a count reply into a dict whose ``count`` is an ``int``.

    Falls back to ``_count_parse_failed()`` when the reply is not a JSON
    object with an integer-convertible ``count``.
    """
    try:
        result = _parse_json_reply(raw)
    except json.JSONDecodeError:
        result = None
    if not isinstance(result, dict) or "count" not in result:
        logger.error("Count JSON parse failed: %s", raw[:200])
        return _count_parse_failed()
    try:
        # Models sometimes answer "5" or 5.0; normalise to int.
        result["count"] = int(result["count"])
    except (TypeError, ValueError, OverflowError):
        logger.error("Count JSON has non-integer count: %s", raw[:200])
        return _count_parse_failed()
    return result


async def count_symbol(symbol_crop: Path, full_image: Path, description: str) -> dict:
    """Pass 2: Count instances of one symbol in the drawing.

    Args:
        symbol_crop: Image file containing the reference symbol from the legend.
        full_image: Image file of the full drawing.
        description: Legend description of the symbol, inserted into the prompt.

    Returns:
        The model's JSON reply as a dict with an ``int`` ``count`` (plus
        whatever ``confidence``/``notes`` the model returned), or
        ``{"count": 0, "confidence": "low", "notes": "parse failed"}`` when
        the reply is unusable.
    """
    # Offload blocking file reads / base64 encoding from the event loop.
    crop_url = await asyncio.to_thread(_data_url, symbol_crop)
    full_url = await asyncio.to_thread(_data_url, full_image)
    prompt = COUNT_PROMPT_TEMPLATE.format(description=description)
    resp = await _get_client().chat.completions.create(
        model=MODEL,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "text", "text": "Reference symbol:"},
                {"type": "image_url", "image_url": {"url": crop_url}},
                {"type": "text", "text": "Full drawing:"},
                {"type": "image_url", "image_url": {"url": full_url}},
            ],
        }],
        max_tokens=300,
        temperature=0.0,
    )
    return _parse_count_reply(_reply_text(resp, "Count"))