"""SSE endpoint per live progress del training. GET /api/trainings/{id}/events Streamma eventi dal Redis stream `ml:train:{id}:events` via Server-Sent Events. Termina quando lo stato del training รจ terminale (succeeded/failed/cancelled). """ from __future__ import annotations import asyncio import json import uuid from fastapi import APIRouter, Depends, HTTPException from sse_starlette.sse import EventSourceResponse from core import db, redis_client from core.auth import require_auth router = APIRouter(prefix="/api/trainings", tags=["trainings-sse"]) _TERMINAL = {"succeeded", "failed", "cancelled"} @router.get("/{training_id}/events") async def training_events(training_id: str, user=Depends(require_auth)): # verifica esistenza row = await db.fetchrow("SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id)) if not row: raise HTTPException(404, "not found") stream_key = f"ml:train:{training_id}:events" status_key = f"ml:train:{training_id}" async def gen(): last_id = "0-0" r = redis_client.client() while True: try: # XREAD block 5s per non tenere la connessione idle troppo a lungo resp = await r.xread({stream_key: last_id}, count=50, block=5000) except Exception as e: yield {"event": "error", "data": json.dumps({"error": str(e)})} await asyncio.sleep(1) continue if resp: for _stream, entries in resp: for entry_id, fields in entries: last_id = entry_id yield {"event": "message", "id": entry_id, "data": json.dumps(fields)} # controlla stato terminale state = await r.hget(status_key, "status") if not state: # fallback su db se redis scaduto db_row = await db.fetchrow( "SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id) ) state = db_row["status"] if db_row else "unknown" if state in _TERMINAL: yield {"event": "end", "data": json.dumps({"status": state})} return return EventSourceResponse(gen())