feat: Add new API endpoints and HTML pages for ML model management

- Implemented HTML pages for datasets, models, training, testing, and results.
- Created API endpoints for managing repositories, results, tests, and training sessions.
- Added functionality for streaming training progress via Server-Sent Events (SSE).
- Introduced a Dockerfile for the ML runner with necessary dependencies.
- Developed an SDK for user code execution within the runner container.
- Enhanced CSS styles for improved UI layout and navigation.
- Established a layout template for consistent HTML structure across pages.
- Added JavaScript for dynamic interactions on the models page.
- Implemented WebSocket handling for real-time communication with kiosk devices and controllers.
- Implemented model registration and management API at /api/models
- Added Gitea proxy API for repository interactions at /api/repos
- Created results API for listing and comparing training results at /api/results
- Developed training management API for enqueueing and retrieving training jobs at /api/trainings
- Introduced SSE endpoint for live training progress updates
- Added HTML pages for models, datasets, and training management
- Created a Dockerfile for the ML runner with necessary dependencies
- Developed SDK for user code execution within the runner container
- Enhanced CSS styles for improved UI/UX
- Implemented WebSocket communication for real-time device and controller interactions in the kiosk system
This commit is contained in:
Giuseppe Raffa
2026-04-28 09:24:38 +02:00
parent ee478e52ef
commit 0ce879aa44
81 changed files with 7491 additions and 746 deletions

160
ml/routers/datasets.py Normal file
View File

@@ -0,0 +1,160 @@
"""API datasets (ml.mebboat.it/api/datasets).
Upload/list/get/download/delete. Storage:
MinIO bucket "ml" con key "datasets/<uuid>.<ext>"
Postgres db "ml" tabella "datasets"
"""
from __future__ import annotations
import json
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from core import db, minio_client
from core.auth import require_auth
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
# Bucket MinIO fisso per tutti i dataset (no prefix nelle key).
BUCKET = "ml.datasets"
_EXT = {"csv": "csv", "json": "json", "netcdf": "nc"}
def _row(r) -> dict:
if r is None:
return None
d = dict(r)
# asyncpg ritorna JSONB come dict già; date/time come datetime
for k in ("created_at", "updated_at", "start_date", "end_date"):
if d.get(k) is not None and hasattr(d[k], "isoformat"):
d[k] = d[k].isoformat()
return d
@router.get("")
async def list_datasets(
type: Optional[str] = Query(None),
tags: Optional[str] = Query(None),
mine: Optional[int] = Query(None),
search: Optional[str] = Query(None),
user=Depends(require_auth),
):
where = []
args: list = []
if type:
args.append(type)
where.append(f"type = ${len(args)}")
if tags:
tag_arr = [t.strip() for t in tags.split(",") if t.strip()]
if tag_arr:
args.append(tag_arr)
where.append(f"tags && ${len(args)}")
if mine and user.get("username"):
args.append(user["username"])
where.append(f"created_by = ${len(args)}")
if search:
args.append(f"%{search}%")
where.append(f"(nome ILIKE ${len(args)} OR description ILIKE ${len(args)})")
sql = "SELECT * FROM datasets"
if where:
sql += " WHERE " + " AND ".join(where)
sql += " ORDER BY created_at DESC LIMIT 500"
rows = await db.fetch(sql, *args)
return {"count": len(rows), "datasets": [_row(r) for r in rows]}
@router.post("", status_code=201)
async def upload_dataset(
file: UploadFile = File(...),
metadata: str = Form("{}"),
user=Depends(require_auth),
):
try:
meta = json.loads(metadata or "{}")
except json.JSONDecodeError:
raise HTTPException(400, "metadata must be valid JSON")
fmt = meta.get("format") or meta.get("type") or "csv"
if fmt not in ("csv", "json", "netcdf"):
fmt = "csv"
ext = _EXT[fmt]
ds_id = str(uuid.uuid4())
file_key = f"{ds_id}.{ext}"
data = await file.read()
minio_client.put_bytes(file_key, data, content_type=file.content_type or "application/octet-stream", bucket=BUCKET)
created_by = user.get("username") or meta.get("created_by") or "unknown"
row = await db.fetchrow(
"""
INSERT INTO datasets (
id, file_key, nome, description, tags, type, format, notes,
created_by, size_bytes, copernicus_id
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
RETURNING *
""",
uuid.UUID(ds_id),
file_key,
meta.get("nome") or file.filename or file_key,
meta.get("description"),
meta.get("tags") or [],
meta.get("dataset_type") or "custom",
fmt,
meta.get("notes"),
created_by,
len(data),
meta.get("copernicus_id") or meta.get("copernicus_dataset_id"),
)
return _row(row)
@router.get("/{dataset_id}")
async def get_dataset(dataset_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT * FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
if not row:
raise HTTPException(404, "not found")
return _row(row)
@router.get("/{dataset_id}/download")
async def download_dataset(dataset_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT file_key FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
if not row:
raise HTTPException(404, "not found")
url = minio_client.presigned_get(row["file_key"], 3600, bucket=BUCKET)
return {"url": url, "expires_in": 3600}
@router.patch("/{dataset_id}")
async def patch_dataset(dataset_id: str, body: dict, user=Depends(require_auth)):
allowed = {"nome", "description", "tags", "notes"}
sets = []
args: list = []
for k, v in body.items():
if k in allowed:
args.append(v)
sets.append(f"{k} = ${len(args)}")
if not sets:
raise HTTPException(400, "no fields to update")
# Trigger updated_at non presente nel DB: lo aggiorniamo manualmente.
sets.append("updated_at = NOW()")
args.append(uuid.UUID(dataset_id))
row = await db.fetchrow(
f"UPDATE datasets SET {', '.join(sets)} WHERE id = ${len(args)} RETURNING *",
*args,
)
if not row:
raise HTTPException(404, "not found")
return _row(row)
@router.delete("/{dataset_id}", status_code=204)
async def delete_dataset(dataset_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT file_key FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
if not row:
raise HTTPException(404, "not found")
minio_client.remove(row["file_key"], bucket=BUCKET)
await db.execute("DELETE FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
return None

131
ml/routers/models.py Normal file
View File

@@ -0,0 +1,131 @@
"""API /api/models — registro modelli (repo Gitea + metadata)."""
from __future__ import annotations
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from core import db
from core.auth import require_auth
from core.model_spec import fetch_and_parse_spec
router = APIRouter(prefix="/api/models", tags=["models"])
def _row(r) -> Optional[dict]:
if r is None:
return None
d = dict(r)
for k in ("created_at", "updated_at"):
if d.get(k) is not None and hasattr(d[k], "isoformat"):
d[k] = d[k].isoformat()
return d
@router.get("")
async def list_models(user=Depends(require_auth)):
rows = await db.fetch("SELECT * FROM models ORDER BY created_at DESC LIMIT 500")
return {"count": len(rows), "models": [_row(r) for r in rows]}
@router.post("", status_code=201)
async def create_model(body: dict, user=Depends(require_auth)):
required = ("name", "type", "gitea_repo")
for k in required:
if not body.get(k):
raise HTTPException(400, f"missing field: {k}")
# prova a pre-caricare model.yml dal default branch (non fatale)
spec = None
try:
spec = await fetch_and_parse_spec(body["gitea_repo"], body.get("default_branch") or "main")
except Exception:
spec = None
row = await db.fetchrow(
"""
INSERT INTO models (name, type, gitea_repo, default_branch, spec, created_by)
VALUES ($1,$2,$3,$4,$5,$6)
RETURNING *
""",
body["name"],
body["type"],
body["gitea_repo"],
body.get("default_branch") or "main",
spec,
user.get("username") or "unknown",
)
return _row(row)
@router.get("/{model_id}")
async def get_model(model_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT * FROM models WHERE id = $1", uuid.UUID(model_id))
if not row:
raise HTTPException(404, "not found")
return _row(row)
@router.patch("/{model_id}")
async def patch_model(model_id: str, body: dict, user=Depends(require_auth)):
allowed = {"name", "type", "default_branch"}
sets = []
args: list = []
for k, v in body.items():
if k in allowed:
args.append(v)
sets.append(f"{k} = ${len(args)}")
if not sets:
raise HTTPException(400, "no fields to update")
args.append(uuid.UUID(model_id))
row = await db.fetchrow(
f"UPDATE models SET {', '.join(sets)} WHERE id = ${len(args)} RETURNING *",
*args,
)
if not row:
raise HTTPException(404, "not found")
return _row(row)
@router.delete("/{model_id}", status_code=204)
async def delete_model(model_id: str, user=Depends(require_auth)):
await db.execute("DELETE FROM models WHERE id = $1", uuid.UUID(model_id))
return None
# ── Notes ──────────────────────────────────────────────────────────────────
@router.get("/{model_id}/notes")
async def list_notes(model_id: str, user=Depends(require_auth)):
rows = await db.fetch(
"SELECT id, author, text, created_at FROM model_notes WHERE model_id = $1 ORDER BY created_at DESC",
uuid.UUID(model_id),
)
return [
{
"id": str(r["id"]),
"author": r["author"],
"text": r["text"],
"created_at": r["created_at"].isoformat(),
}
for r in rows
]
@router.post("/{model_id}/notes", status_code=201)
async def add_note(model_id: str, body: dict, user=Depends(require_auth)):
text = (body.get("text") or "").strip()
if not text:
raise HTTPException(400, "text required")
row = await db.fetchrow(
"INSERT INTO model_notes (model_id, author, text) VALUES ($1, $2, $3) RETURNING *",
uuid.UUID(model_id),
user.get("username") or "unknown",
text,
)
return {
"id": str(row["id"]),
"author": row["author"],
"text": row["text"],
"created_at": row["created_at"].isoformat(),
}

75
ml/routers/pages.py Normal file
View File

@@ -0,0 +1,75 @@
"""Pagine HTML servite direttamente da ml.mebboat.it.
Layout:
/ redirect a /datasets (o landing console)
/datasets lista/upload dataset
/models registro modelli
/train avvia training
/test esegue test su modello trainato
/results storico e confronto risultati
"""
from __future__ import annotations
from pathlib import Path
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from core.auth import _verify
from core.config import settings
router = APIRouter(tags=["pages"])
TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates"
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
def _user_or_redirect(request: Request):
"""Per le pagine, se non autenticato redirect al login. Ritorna user dict o RedirectResponse."""
token = request.cookies.get("auth_token")
auth = request.headers.get("authorization")
if not token and auth and auth.startswith("Bearer "):
token = auth[7:]
user = _verify(token)
if not user:
target = str(request.url)
return RedirectResponse(url=f"{settings.auth_login_url}?redirect={target}", status_code=302)
return user
def _render(request: Request, template: str, **ctx):
user = _user_or_redirect(request)
if isinstance(user, RedirectResponse):
return user
return templates.TemplateResponse(template, {"request": request, "user": user, **ctx})
@router.get("/", response_class=HTMLResponse)
async def home(request: Request):
return RedirectResponse(url="/datasets")
@router.get("/datasets", response_class=HTMLResponse)
async def page_datasets(request: Request):
return _render(request, "datasets.html", page="datasets")
@router.get("/models", response_class=HTMLResponse)
async def page_models(request: Request):
return _render(request, "models.html", page="models")
@router.get("/train", response_class=HTMLResponse)
async def page_train(request: Request):
return _render(request, "train.html", page="train")
@router.get("/test", response_class=HTMLResponse)
async def page_test(request: Request):
return _render(request, "test.html", page="test")
@router.get("/results", response_class=HTMLResponse)
async def page_results(request: Request):
return _render(request, "results.html", page="results")

51
ml/routers/repos.py Normal file
View File

@@ -0,0 +1,51 @@
"""API /api/repos — proxy autenticato verso Gitea."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query
from core import gitea
from core.auth import require_auth
from core.model_spec import fetch_and_parse_spec
router = APIRouter(prefix="/api/repos", tags=["repos"])
@router.get("")
async def list_repos(user=Depends(require_auth)):
try:
return await gitea.list_repos()
except Exception as e:
raise HTTPException(502, f"gitea: {e}")
@router.get("/{owner}/{repo}/branches")
async def branches(owner: str, repo: str, user=Depends(require_auth)):
try:
return await gitea.list_branches(f"{owner}/{repo}")
except Exception as e:
raise HTTPException(502, f"gitea: {e}")
@router.get("/{owner}/{repo}/commits")
async def commits(owner: str, repo: str, branch: str = Query("main"), user=Depends(require_auth)):
try:
return await gitea.list_commits(f"{owner}/{repo}", branch)
except Exception as e:
raise HTTPException(502, f"gitea: {e}")
@router.get("/{owner}/{repo}/file")
async def file_raw(owner: str, repo: str, ref: str, path: str, user=Depends(require_auth)):
try:
raw = await gitea.get_file_raw(f"{owner}/{repo}", ref, path)
return {"content": raw.decode("utf-8", errors="replace"), "size": len(raw)}
except Exception as e:
raise HTTPException(404, f"file not found: {e}")
@router.get("/{owner}/{repo}/spec")
async def spec(owner: str, repo: str, ref: str = Query("main"), user=Depends(require_auth)):
s = await fetch_and_parse_spec(f"{owner}/{repo}", ref)
if s is None:
raise HTTPException(404, "model.yml not found at ref")
return s

89
ml/routers/results.py Normal file
View File

@@ -0,0 +1,89 @@
"""API /api/results — lista trainings/tests + compare multi-training."""
from __future__ import annotations
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from core import db, influx_client
from core.auth import require_auth
from core.config import settings
router = APIRouter(prefix="/api/results", tags=["results"])
def _row(r):
if r is None:
return None
d = dict(r)
for k in ("queued_at", "started_at", "finished_at", "started_at", "ended_at"):
if d.get(k) is not None and hasattr(d[k], "isoformat"):
d[k] = d[k].isoformat()
return d
@router.get("")
async def list_results(
model_id: Optional[str] = Query(None),
user=Depends(require_auth),
):
where = []
args: list = []
if model_id:
args.append(uuid.UUID(model_id))
where.append(f"model_id = ${len(args)}")
sql = "SELECT * FROM trainings"
if where:
sql += " WHERE " + " AND ".join(where)
sql += " ORDER BY finished_at DESC NULLS LAST, queued_at DESC LIMIT 200"
rows = await db.fetch(sql, *args)
return {"count": len(rows), "trainings": [_row(r) for r in rows]}
@router.get("/{training_id}")
async def get_result(training_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", uuid.UUID(training_id))
if not row:
raise HTTPException(404, "not found")
# timeseries via Influx: loss per iter + cpu/mem
flux = (
f'from(bucket:"{settings.influx_bucket}") '
f'|> range(start:-90d) '
f'|> filter(fn: (r) => r._measurement == "ml_training" and r.training_id == "{training_id}")'
)
try:
ts = await influx_client.query_flux(flux)
except Exception:
ts = []
return {"training": _row(row), "timeseries": ts}
@router.get("/compare")
async def compare(
trainings: str = Query(..., description="comma-separated training IDs"),
user=Depends(require_auth),
):
ids = [s.strip() for s in trainings.split(",") if s.strip()]
if len(ids) < 2:
raise HTTPException(400, "at least 2 training IDs required")
out = []
for tid in ids:
try:
tid_uuid = uuid.UUID(tid)
except ValueError:
continue
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", tid_uuid)
if not row:
continue
flux = (
f'from(bucket:"{settings.influx_bucket}") '
f'|> range(start:-90d) '
f'|> filter(fn: (r) => r._measurement == "ml_training" and r.training_id == "{tid}")'
)
try:
ts = await influx_client.query_flux(flux)
except Exception:
ts = []
out.append({"training": _row(row), "timeseries": ts})
return {"results": out}

109
ml/routers/tests.py Normal file
View File

@@ -0,0 +1,109 @@
"""API /api/tests — sessioni di test su training esistente (max 2 utenti simultanei)."""
from __future__ import annotations
import json
import time
import uuid
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
from core import api_client, db, minio_client
from core.auth import require_auth
from core.docker_runner import run_test_once
router = APIRouter(prefix="/api/tests", tags=["tests"])
def _row(r):
if r is None:
return None
d = dict(r)
for k in ("started_at", "ended_at"):
if d.get(k) is not None and hasattr(d[k], "isoformat"):
d[k] = d[k].isoformat()
return d
@router.post("/sessions", status_code=201)
async def start_session(body: dict, user=Depends(require_auth)):
training_id = body.get("training_id")
if not training_id:
raise HTTPException(400, "training_id required")
tr = await db.fetchrow(
"SELECT id, status FROM trainings WHERE id = $1", uuid.UUID(training_id)
)
if not tr:
raise HTTPException(404, "training not found")
if tr["status"] != "succeeded":
raise HTTPException(409, "training not completed")
sid = str(uuid.uuid4())
try:
await api_client.page_connect("test", user.get("username") or "unknown", sid)
except httpx.HTTPStatusError as e:
if e.response.status_code == 429:
raise HTTPException(429, "test slots full (max 2 users)")
raise HTTPException(502, f"api: {e}")
row = await db.fetchrow(
"INSERT INTO tests (id, training_id, user_id) VALUES ($1,$2,$3) RETURNING *",
uuid.UUID(sid),
uuid.UUID(training_id),
user.get("username") or "unknown",
)
return _row(row)
@router.post("/sessions/{session_id}/ping")
async def ping_session(session_id: str, user=Depends(require_auth)):
try:
await api_client.page_ping(session_id)
except httpx.HTTPStatusError as e:
raise HTTPException(e.response.status_code, e.response.text)
return {"ok": True}
@router.post("/sessions/{session_id}/runs", status_code=201)
async def run_test(session_id: str, body: dict, user=Depends(require_auth)):
row = await db.fetchrow("SELECT * FROM tests WHERE id = $1", uuid.UUID(session_id))
if not row:
raise HTTPException(404, "session not found")
inputs = body.get("inputs") or {}
t0 = time.monotonic()
try:
result = await run_test_once(str(row["training_id"]), inputs)
except Exception as e:
raise HTTPException(500, f"test run failed: {e}")
dt_ms = int((time.monotonic() - t0) * 1000)
run = {
"inputs": inputs,
"outputs": result.get("outputs", {}),
"duration_ms": dt_ms,
"cpu_peak": result.get("cpu_peak"),
"mem_peak_mb": result.get("mem_peak_mb"),
"ts": time.time(),
}
await db.execute(
"UPDATE tests SET runs = runs || $1::jsonb WHERE id = $2",
json.dumps([run]),
uuid.UUID(session_id),
)
return run
@router.delete("/sessions/{session_id}", status_code=204)
async def end_session(session_id: str, user=Depends(require_auth)):
await db.execute(
"UPDATE tests SET ended_at = NOW() WHERE id = $1 AND ended_at IS NULL",
uuid.UUID(session_id),
)
try:
await api_client.page_disconnect(session_id)
except Exception:
pass
return None

129
ml/routers/trainings.py Normal file
View File

@@ -0,0 +1,129 @@
"""API /api/trainings — enqueue, list, get, artifacts."""
from __future__ import annotations
import json
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from core import db, minio_client, redis_client, api_client
from core.auth import require_auth
router = APIRouter(prefix="/api/trainings", tags=["trainings"])
def _row(r) -> Optional[dict]:
if r is None:
return None
d = dict(r)
for k in ("queued_at", "started_at", "finished_at"):
if d.get(k) is not None and hasattr(d[k], "isoformat"):
d[k] = d[k].isoformat()
return d
@router.get("")
async def list_trainings(
model_id: Optional[str] = Query(None),
status: Optional[str] = Query(None),
limit: int = Query(100, le=500),
user=Depends(require_auth),
):
where = []
args: list = []
if model_id:
args.append(uuid.UUID(model_id))
where.append(f"model_id = ${len(args)}")
if status:
args.append(status)
where.append(f"status = ${len(args)}")
sql = "SELECT * FROM trainings"
if where:
sql += " WHERE " + " AND ".join(where)
args.append(limit)
sql += f" ORDER BY queued_at DESC LIMIT ${len(args)}"
rows = await db.fetch(sql, *args)
return {"count": len(rows), "trainings": [_row(r) for r in rows]}
@router.post("", status_code=202)
async def enqueue_training(body: dict, user=Depends(require_auth)):
for k in ("model_id", "version", "patch", "dataset_id"):
if not body.get(k):
raise HTTPException(400, f"missing field: {k}")
model_row = await db.fetchrow("SELECT * FROM models WHERE id = $1", uuid.UUID(body["model_id"]))
if not model_row:
raise HTTPException(404, "model not found")
ds_row = await db.fetchrow("SELECT id FROM datasets WHERE id = $1", uuid.UUID(body["dataset_id"]))
if not ds_row:
raise HTTPException(404, "dataset not found")
try:
training_row = await db.fetchrow(
"""
INSERT INTO trainings (model_id, version, patch, dataset_id, started_by, status)
VALUES ($1,$2,$3,$4,$5,'queued')
RETURNING *
""",
uuid.UUID(body["model_id"]),
body["version"],
body["patch"],
uuid.UUID(body["dataset_id"]),
user.get("username") or "unknown",
)
except Exception as e:
raise HTTPException(409, f"training already exists or invalid: {e}")
training_id = str(training_row["id"])
# crea job lato api-service (cross-service registry)
try:
await api_client.create_job(
"train",
created_by=user.get("username") or "unknown",
payload={
"training_id": training_id,
"model_id": body["model_id"],
"version": body["version"],
"patch": body["patch"],
"dataset_id": body["dataset_id"],
},
)
except Exception as e:
# non-fatale: il worker locale può comunque procedere; logghiamo e continuiamo
import logging
logging.warning("create_job failed: %s", e)
# enqueue in Redis (il worker locale lo raccoglie)
await redis_client.client().lpush("ml:queue:train", training_id)
await redis_client.client().hset(
f"ml:train:{training_id}",
mapping={"status": "queued", "progress": "0", "message": "queued"},
)
await redis_client.client().expire(f"ml:train:{training_id}", 48 * 3600)
return _row(training_row)
@router.get("/{training_id}")
async def get_training(training_id: str, user=Depends(require_auth)):
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", uuid.UUID(training_id))
if not row:
raise HTTPException(404, "not found")
return _row(row)
@router.get("/{training_id}/artifacts")
async def list_artifacts(training_id: str, user=Depends(require_auth)):
row = await db.fetchrow(
"SELECT artifacts_prefix FROM trainings WHERE id = $1", uuid.UUID(training_id)
)
if not row or not row["artifacts_prefix"]:
raise HTTPException(404, "no artifacts")
objs = minio_client.list_prefix(row["artifacts_prefix"] + "/")
for o in objs:
o["url"] = minio_client.presigned_get(o["name"], 3600)
return objs

View File

@@ -0,0 +1,64 @@
"""SSE endpoint per live progress del training.
GET /api/trainings/{id}/events
Streamma eventi dal Redis stream `ml:train:{id}:events` via Server-Sent Events.
Termina quando lo stato del training è terminale (succeeded/failed/cancelled).
"""
from __future__ import annotations
import asyncio
import json
import uuid
from fastapi import APIRouter, Depends, HTTPException
from sse_starlette.sse import EventSourceResponse
from core import db, redis_client
from core.auth import require_auth
router = APIRouter(prefix="/api/trainings", tags=["trainings-sse"])
_TERMINAL = {"succeeded", "failed", "cancelled"}
@router.get("/{training_id}/events")
async def training_events(training_id: str, user=Depends(require_auth)):
# verifica esistenza
row = await db.fetchrow("SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id))
if not row:
raise HTTPException(404, "not found")
stream_key = f"ml:train:{training_id}:events"
status_key = f"ml:train:{training_id}"
async def gen():
last_id = "0-0"
r = redis_client.client()
while True:
try:
# XREAD block 5s per non tenere la connessione idle troppo a lungo
resp = await r.xread({stream_key: last_id}, count=50, block=5000)
except Exception as e:
yield {"event": "error", "data": json.dumps({"error": str(e)})}
await asyncio.sleep(1)
continue
if resp:
for _stream, entries in resp:
for entry_id, fields in entries:
last_id = entry_id
yield {"event": "message", "id": entry_id, "data": json.dumps(fields)}
# controlla stato terminale
state = await r.hget(status_key, "status")
if not state:
# fallback su db se redis scaduto
db_row = await db.fetchrow(
"SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id)
)
state = db_row["status"] if db_row else "unknown"
if state in _TERMINAL:
yield {"event": "end", "data": json.dumps({"status": state})}
return
return EventSourceResponse(gen())