feat: Add new API endpoints and HTML pages for ML model management
- Implemented HTML pages for datasets, models, training, testing, and results. - Created API endpoints for managing repositories, results, tests, and training sessions. - Added functionality for streaming training progress via Server-Sent Events (SSE). - Introduced a Dockerfile for the ML runner with necessary dependencies. - Developed an SDK for user code execution within the runner container. - Enhanced CSS styles for improved UI layout and navigation. - Established a layout template for consistent HTML structure across pages. - Added JavaScript for dynamic interactions on the models page. - Implemented WebSocket handling for real-time communication with kiosk devices and controllers. - Implemented model registration and management API at /api/models - Added Gitea proxy API for repository interactions at /api/repos - Created results API for listing and comparing training results at /api/results - Developed training management API for enqueueing and retrieving training jobs at /api/trainings - Introduced SSE endpoint for live training progress updates - Added HTML pages for models, datasets, and training management - Created a Dockerfile for the ML runner with necessary dependencies - Developed SDK for user code execution within the runner container - Enhanced CSS styles for improved UI/UX - Implemented WebSocket communication for real-time device and controller interactions in the kiosk system
This commit is contained in:
160
ml/routers/datasets.py
Normal file
160
ml/routers/datasets.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""API datasets (ml.mebboat.it/api/datasets).
|
||||
|
||||
Upload/list/get/download/delete. Storage:
|
||||
MinIO bucket "ml" con key "datasets/<uuid>.<ext>"
|
||||
Postgres db "ml" tabella "datasets"
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
|
||||
from core import db, minio_client
|
||||
from core.auth import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
# Bucket MinIO fisso per tutti i dataset (no prefix nelle key).
|
||||
BUCKET = "ml.datasets"
|
||||
_EXT = {"csv": "csv", "json": "json", "netcdf": "nc"}
|
||||
|
||||
|
||||
def _row(r) -> dict:
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
# asyncpg ritorna JSONB come dict già; date/time come datetime
|
||||
for k in ("created_at", "updated_at", "start_date", "end_date"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_datasets(
|
||||
type: Optional[str] = Query(None),
|
||||
tags: Optional[str] = Query(None),
|
||||
mine: Optional[int] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
user=Depends(require_auth),
|
||||
):
|
||||
where = []
|
||||
args: list = []
|
||||
if type:
|
||||
args.append(type)
|
||||
where.append(f"type = ${len(args)}")
|
||||
if tags:
|
||||
tag_arr = [t.strip() for t in tags.split(",") if t.strip()]
|
||||
if tag_arr:
|
||||
args.append(tag_arr)
|
||||
where.append(f"tags && ${len(args)}")
|
||||
if mine and user.get("username"):
|
||||
args.append(user["username"])
|
||||
where.append(f"created_by = ${len(args)}")
|
||||
if search:
|
||||
args.append(f"%{search}%")
|
||||
where.append(f"(nome ILIKE ${len(args)} OR description ILIKE ${len(args)})")
|
||||
sql = "SELECT * FROM datasets"
|
||||
if where:
|
||||
sql += " WHERE " + " AND ".join(where)
|
||||
sql += " ORDER BY created_at DESC LIMIT 500"
|
||||
rows = await db.fetch(sql, *args)
|
||||
return {"count": len(rows), "datasets": [_row(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
async def upload_dataset(
|
||||
file: UploadFile = File(...),
|
||||
metadata: str = Form("{}"),
|
||||
user=Depends(require_auth),
|
||||
):
|
||||
try:
|
||||
meta = json.loads(metadata or "{}")
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(400, "metadata must be valid JSON")
|
||||
|
||||
fmt = meta.get("format") or meta.get("type") or "csv"
|
||||
if fmt not in ("csv", "json", "netcdf"):
|
||||
fmt = "csv"
|
||||
ext = _EXT[fmt]
|
||||
ds_id = str(uuid.uuid4())
|
||||
file_key = f"{ds_id}.{ext}"
|
||||
|
||||
data = await file.read()
|
||||
minio_client.put_bytes(file_key, data, content_type=file.content_type or "application/octet-stream", bucket=BUCKET)
|
||||
|
||||
created_by = user.get("username") or meta.get("created_by") or "unknown"
|
||||
row = await db.fetchrow(
|
||||
"""
|
||||
INSERT INTO datasets (
|
||||
id, file_key, nome, description, tags, type, format, notes,
|
||||
created_by, size_bytes, copernicus_id
|
||||
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
|
||||
RETURNING *
|
||||
""",
|
||||
uuid.UUID(ds_id),
|
||||
file_key,
|
||||
meta.get("nome") or file.filename or file_key,
|
||||
meta.get("description"),
|
||||
meta.get("tags") or [],
|
||||
meta.get("dataset_type") or "custom",
|
||||
fmt,
|
||||
meta.get("notes"),
|
||||
created_by,
|
||||
len(data),
|
||||
meta.get("copernicus_id") or meta.get("copernicus_dataset_id"),
|
||||
)
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.get("/{dataset_id}")
|
||||
async def get_dataset(dataset_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT * FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.get("/{dataset_id}/download")
|
||||
async def download_dataset(dataset_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT file_key FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
url = minio_client.presigned_get(row["file_key"], 3600, bucket=BUCKET)
|
||||
return {"url": url, "expires_in": 3600}
|
||||
|
||||
|
||||
@router.patch("/{dataset_id}")
|
||||
async def patch_dataset(dataset_id: str, body: dict, user=Depends(require_auth)):
|
||||
allowed = {"nome", "description", "tags", "notes"}
|
||||
sets = []
|
||||
args: list = []
|
||||
for k, v in body.items():
|
||||
if k in allowed:
|
||||
args.append(v)
|
||||
sets.append(f"{k} = ${len(args)}")
|
||||
if not sets:
|
||||
raise HTTPException(400, "no fields to update")
|
||||
# Trigger updated_at non presente nel DB: lo aggiorniamo manualmente.
|
||||
sets.append("updated_at = NOW()")
|
||||
args.append(uuid.UUID(dataset_id))
|
||||
row = await db.fetchrow(
|
||||
f"UPDATE datasets SET {', '.join(sets)} WHERE id = ${len(args)} RETURNING *",
|
||||
*args,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.delete("/{dataset_id}", status_code=204)
|
||||
async def delete_dataset(dataset_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT file_key FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
minio_client.remove(row["file_key"], bucket=BUCKET)
|
||||
await db.execute("DELETE FROM datasets WHERE id = $1", uuid.UUID(dataset_id))
|
||||
return None
|
||||
131
ml/routers/models.py
Normal file
131
ml/routers/models.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""API /api/models — registro modelli (repo Gitea + metadata)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from core import db
|
||||
from core.auth import require_auth
|
||||
from core.model_spec import fetch_and_parse_spec
|
||||
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
|
||||
def _row(r) -> Optional[dict]:
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
for k in ("created_at", "updated_at"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_models(user=Depends(require_auth)):
|
||||
rows = await db.fetch("SELECT * FROM models ORDER BY created_at DESC LIMIT 500")
|
||||
return {"count": len(rows), "models": [_row(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
async def create_model(body: dict, user=Depends(require_auth)):
|
||||
required = ("name", "type", "gitea_repo")
|
||||
for k in required:
|
||||
if not body.get(k):
|
||||
raise HTTPException(400, f"missing field: {k}")
|
||||
|
||||
# prova a pre-caricare model.yml dal default branch (non fatale)
|
||||
spec = None
|
||||
try:
|
||||
spec = await fetch_and_parse_spec(body["gitea_repo"], body.get("default_branch") or "main")
|
||||
except Exception:
|
||||
spec = None
|
||||
|
||||
row = await db.fetchrow(
|
||||
"""
|
||||
INSERT INTO models (name, type, gitea_repo, default_branch, spec, created_by)
|
||||
VALUES ($1,$2,$3,$4,$5,$6)
|
||||
RETURNING *
|
||||
""",
|
||||
body["name"],
|
||||
body["type"],
|
||||
body["gitea_repo"],
|
||||
body.get("default_branch") or "main",
|
||||
spec,
|
||||
user.get("username") or "unknown",
|
||||
)
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.get("/{model_id}")
|
||||
async def get_model(model_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT * FROM models WHERE id = $1", uuid.UUID(model_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.patch("/{model_id}")
|
||||
async def patch_model(model_id: str, body: dict, user=Depends(require_auth)):
|
||||
allowed = {"name", "type", "default_branch"}
|
||||
sets = []
|
||||
args: list = []
|
||||
for k, v in body.items():
|
||||
if k in allowed:
|
||||
args.append(v)
|
||||
sets.append(f"{k} = ${len(args)}")
|
||||
if not sets:
|
||||
raise HTTPException(400, "no fields to update")
|
||||
args.append(uuid.UUID(model_id))
|
||||
row = await db.fetchrow(
|
||||
f"UPDATE models SET {', '.join(sets)} WHERE id = ${len(args)} RETURNING *",
|
||||
*args,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.delete("/{model_id}", status_code=204)
|
||||
async def delete_model(model_id: str, user=Depends(require_auth)):
|
||||
await db.execute("DELETE FROM models WHERE id = $1", uuid.UUID(model_id))
|
||||
return None
|
||||
|
||||
|
||||
# ── Notes ──────────────────────────────────────────────────────────────────
|
||||
@router.get("/{model_id}/notes")
|
||||
async def list_notes(model_id: str, user=Depends(require_auth)):
|
||||
rows = await db.fetch(
|
||||
"SELECT id, author, text, created_at FROM model_notes WHERE model_id = $1 ORDER BY created_at DESC",
|
||||
uuid.UUID(model_id),
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": str(r["id"]),
|
||||
"author": r["author"],
|
||||
"text": r["text"],
|
||||
"created_at": r["created_at"].isoformat(),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/{model_id}/notes", status_code=201)
|
||||
async def add_note(model_id: str, body: dict, user=Depends(require_auth)):
|
||||
text = (body.get("text") or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "text required")
|
||||
row = await db.fetchrow(
|
||||
"INSERT INTO model_notes (model_id, author, text) VALUES ($1, $2, $3) RETURNING *",
|
||||
uuid.UUID(model_id),
|
||||
user.get("username") or "unknown",
|
||||
text,
|
||||
)
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"author": row["author"],
|
||||
"text": row["text"],
|
||||
"created_at": row["created_at"].isoformat(),
|
||||
}
|
||||
75
ml/routers/pages.py
Normal file
75
ml/routers/pages.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Pagine HTML servite direttamente da ml.mebboat.it.
|
||||
|
||||
Layout:
|
||||
/ redirect a /datasets (o landing console)
|
||||
/datasets lista/upload dataset
|
||||
/models registro modelli
|
||||
/train avvia training
|
||||
/test esegue test su modello trainato
|
||||
/results storico e confronto risultati
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from core.auth import _verify
|
||||
from core.config import settings
|
||||
|
||||
router = APIRouter(tags=["pages"])
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates"
|
||||
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
|
||||
|
||||
|
||||
def _user_or_redirect(request: Request):
|
||||
"""Per le pagine, se non autenticato redirect al login. Ritorna user dict o RedirectResponse."""
|
||||
token = request.cookies.get("auth_token")
|
||||
auth = request.headers.get("authorization")
|
||||
if not token and auth and auth.startswith("Bearer "):
|
||||
token = auth[7:]
|
||||
user = _verify(token)
|
||||
if not user:
|
||||
target = str(request.url)
|
||||
return RedirectResponse(url=f"{settings.auth_login_url}?redirect={target}", status_code=302)
|
||||
return user
|
||||
|
||||
|
||||
def _render(request: Request, template: str, **ctx):
|
||||
user = _user_or_redirect(request)
|
||||
if isinstance(user, RedirectResponse):
|
||||
return user
|
||||
return templates.TemplateResponse(template, {"request": request, "user": user, **ctx})
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
|
||||
async def home(request: Request):
|
||||
return RedirectResponse(url="/datasets")
|
||||
|
||||
|
||||
@router.get("/datasets", response_class=HTMLResponse)
|
||||
async def page_datasets(request: Request):
|
||||
return _render(request, "datasets.html", page="datasets")
|
||||
|
||||
|
||||
@router.get("/models", response_class=HTMLResponse)
|
||||
async def page_models(request: Request):
|
||||
return _render(request, "models.html", page="models")
|
||||
|
||||
|
||||
@router.get("/train", response_class=HTMLResponse)
|
||||
async def page_train(request: Request):
|
||||
return _render(request, "train.html", page="train")
|
||||
|
||||
|
||||
@router.get("/test", response_class=HTMLResponse)
|
||||
async def page_test(request: Request):
|
||||
return _render(request, "test.html", page="test")
|
||||
|
||||
|
||||
@router.get("/results", response_class=HTMLResponse)
|
||||
async def page_results(request: Request):
|
||||
return _render(request, "results.html", page="results")
|
||||
51
ml/routers/repos.py
Normal file
51
ml/routers/repos.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""API /api/repos — proxy autenticato verso Gitea."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from core import gitea
|
||||
from core.auth import require_auth
|
||||
from core.model_spec import fetch_and_parse_spec
|
||||
|
||||
router = APIRouter(prefix="/api/repos", tags=["repos"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_repos(user=Depends(require_auth)):
|
||||
try:
|
||||
return await gitea.list_repos()
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"gitea: {e}")
|
||||
|
||||
|
||||
@router.get("/{owner}/{repo}/branches")
|
||||
async def branches(owner: str, repo: str, user=Depends(require_auth)):
|
||||
try:
|
||||
return await gitea.list_branches(f"{owner}/{repo}")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"gitea: {e}")
|
||||
|
||||
|
||||
@router.get("/{owner}/{repo}/commits")
|
||||
async def commits(owner: str, repo: str, branch: str = Query("main"), user=Depends(require_auth)):
|
||||
try:
|
||||
return await gitea.list_commits(f"{owner}/{repo}", branch)
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"gitea: {e}")
|
||||
|
||||
|
||||
@router.get("/{owner}/{repo}/file")
|
||||
async def file_raw(owner: str, repo: str, ref: str, path: str, user=Depends(require_auth)):
|
||||
try:
|
||||
raw = await gitea.get_file_raw(f"{owner}/{repo}", ref, path)
|
||||
return {"content": raw.decode("utf-8", errors="replace"), "size": len(raw)}
|
||||
except Exception as e:
|
||||
raise HTTPException(404, f"file not found: {e}")
|
||||
|
||||
|
||||
@router.get("/{owner}/{repo}/spec")
|
||||
async def spec(owner: str, repo: str, ref: str = Query("main"), user=Depends(require_auth)):
|
||||
s = await fetch_and_parse_spec(f"{owner}/{repo}", ref)
|
||||
if s is None:
|
||||
raise HTTPException(404, "model.yml not found at ref")
|
||||
return s
|
||||
89
ml/routers/results.py
Normal file
89
ml/routers/results.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""API /api/results — lista trainings/tests + compare multi-training."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from core import db, influx_client
|
||||
from core.auth import require_auth
|
||||
from core.config import settings
|
||||
|
||||
router = APIRouter(prefix="/api/results", tags=["results"])
|
||||
|
||||
|
||||
def _row(r):
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
for k in ("queued_at", "started_at", "finished_at", "started_at", "ended_at"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_results(
|
||||
model_id: Optional[str] = Query(None),
|
||||
user=Depends(require_auth),
|
||||
):
|
||||
where = []
|
||||
args: list = []
|
||||
if model_id:
|
||||
args.append(uuid.UUID(model_id))
|
||||
where.append(f"model_id = ${len(args)}")
|
||||
sql = "SELECT * FROM trainings"
|
||||
if where:
|
||||
sql += " WHERE " + " AND ".join(where)
|
||||
sql += " ORDER BY finished_at DESC NULLS LAST, queued_at DESC LIMIT 200"
|
||||
rows = await db.fetch(sql, *args)
|
||||
return {"count": len(rows), "trainings": [_row(r) for r in rows]}
|
||||
|
||||
|
||||
@router.get("/{training_id}")
|
||||
async def get_result(training_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", uuid.UUID(training_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
# timeseries via Influx: loss per iter + cpu/mem
|
||||
flux = (
|
||||
f'from(bucket:"{settings.influx_bucket}") '
|
||||
f'|> range(start:-90d) '
|
||||
f'|> filter(fn: (r) => r._measurement == "ml_training" and r.training_id == "{training_id}")'
|
||||
)
|
||||
try:
|
||||
ts = await influx_client.query_flux(flux)
|
||||
except Exception:
|
||||
ts = []
|
||||
return {"training": _row(row), "timeseries": ts}
|
||||
|
||||
|
||||
@router.get("/compare")
|
||||
async def compare(
|
||||
trainings: str = Query(..., description="comma-separated training IDs"),
|
||||
user=Depends(require_auth),
|
||||
):
|
||||
ids = [s.strip() for s in trainings.split(",") if s.strip()]
|
||||
if len(ids) < 2:
|
||||
raise HTTPException(400, "at least 2 training IDs required")
|
||||
out = []
|
||||
for tid in ids:
|
||||
try:
|
||||
tid_uuid = uuid.UUID(tid)
|
||||
except ValueError:
|
||||
continue
|
||||
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", tid_uuid)
|
||||
if not row:
|
||||
continue
|
||||
flux = (
|
||||
f'from(bucket:"{settings.influx_bucket}") '
|
||||
f'|> range(start:-90d) '
|
||||
f'|> filter(fn: (r) => r._measurement == "ml_training" and r.training_id == "{tid}")'
|
||||
)
|
||||
try:
|
||||
ts = await influx_client.query_flux(flux)
|
||||
except Exception:
|
||||
ts = []
|
||||
out.append({"training": _row(row), "timeseries": ts})
|
||||
return {"results": out}
|
||||
109
ml/routers/tests.py
Normal file
109
ml/routers/tests.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""API /api/tests — sessioni di test su training esistente (max 2 utenti simultanei)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from core import api_client, db, minio_client
|
||||
from core.auth import require_auth
|
||||
from core.docker_runner import run_test_once
|
||||
|
||||
router = APIRouter(prefix="/api/tests", tags=["tests"])
|
||||
|
||||
|
||||
def _row(r):
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
for k in ("started_at", "ended_at"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.post("/sessions", status_code=201)
|
||||
async def start_session(body: dict, user=Depends(require_auth)):
|
||||
training_id = body.get("training_id")
|
||||
if not training_id:
|
||||
raise HTTPException(400, "training_id required")
|
||||
|
||||
tr = await db.fetchrow(
|
||||
"SELECT id, status FROM trainings WHERE id = $1", uuid.UUID(training_id)
|
||||
)
|
||||
if not tr:
|
||||
raise HTTPException(404, "training not found")
|
||||
if tr["status"] != "succeeded":
|
||||
raise HTTPException(409, "training not completed")
|
||||
|
||||
sid = str(uuid.uuid4())
|
||||
try:
|
||||
await api_client.page_connect("test", user.get("username") or "unknown", sid)
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 429:
|
||||
raise HTTPException(429, "test slots full (max 2 users)")
|
||||
raise HTTPException(502, f"api: {e}")
|
||||
|
||||
row = await db.fetchrow(
|
||||
"INSERT INTO tests (id, training_id, user_id) VALUES ($1,$2,$3) RETURNING *",
|
||||
uuid.UUID(sid),
|
||||
uuid.UUID(training_id),
|
||||
user.get("username") or "unknown",
|
||||
)
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/ping")
|
||||
async def ping_session(session_id: str, user=Depends(require_auth)):
|
||||
try:
|
||||
await api_client.page_ping(session_id)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(e.response.status_code, e.response.text)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/runs", status_code=201)
|
||||
async def run_test(session_id: str, body: dict, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT * FROM tests WHERE id = $1", uuid.UUID(session_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "session not found")
|
||||
|
||||
inputs = body.get("inputs") or {}
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
result = await run_test_once(str(row["training_id"]), inputs)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"test run failed: {e}")
|
||||
dt_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
run = {
|
||||
"inputs": inputs,
|
||||
"outputs": result.get("outputs", {}),
|
||||
"duration_ms": dt_ms,
|
||||
"cpu_peak": result.get("cpu_peak"),
|
||||
"mem_peak_mb": result.get("mem_peak_mb"),
|
||||
"ts": time.time(),
|
||||
}
|
||||
await db.execute(
|
||||
"UPDATE tests SET runs = runs || $1::jsonb WHERE id = $2",
|
||||
json.dumps([run]),
|
||||
uuid.UUID(session_id),
|
||||
)
|
||||
return run
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", status_code=204)
|
||||
async def end_session(session_id: str, user=Depends(require_auth)):
|
||||
await db.execute(
|
||||
"UPDATE tests SET ended_at = NOW() WHERE id = $1 AND ended_at IS NULL",
|
||||
uuid.UUID(session_id),
|
||||
)
|
||||
try:
|
||||
await api_client.page_disconnect(session_id)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
129
ml/routers/trainings.py
Normal file
129
ml/routers/trainings.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""API /api/trainings — enqueue, list, get, artifacts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from core import db, minio_client, redis_client, api_client
|
||||
from core.auth import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/trainings", tags=["trainings"])
|
||||
|
||||
|
||||
def _row(r) -> Optional[dict]:
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
for k in ("queued_at", "started_at", "finished_at"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_trainings(
|
||||
model_id: Optional[str] = Query(None),
|
||||
status: Optional[str] = Query(None),
|
||||
limit: int = Query(100, le=500),
|
||||
user=Depends(require_auth),
|
||||
):
|
||||
where = []
|
||||
args: list = []
|
||||
if model_id:
|
||||
args.append(uuid.UUID(model_id))
|
||||
where.append(f"model_id = ${len(args)}")
|
||||
if status:
|
||||
args.append(status)
|
||||
where.append(f"status = ${len(args)}")
|
||||
sql = "SELECT * FROM trainings"
|
||||
if where:
|
||||
sql += " WHERE " + " AND ".join(where)
|
||||
args.append(limit)
|
||||
sql += f" ORDER BY queued_at DESC LIMIT ${len(args)}"
|
||||
rows = await db.fetch(sql, *args)
|
||||
return {"count": len(rows), "trainings": [_row(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("", status_code=202)
|
||||
async def enqueue_training(body: dict, user=Depends(require_auth)):
|
||||
for k in ("model_id", "version", "patch", "dataset_id"):
|
||||
if not body.get(k):
|
||||
raise HTTPException(400, f"missing field: {k}")
|
||||
|
||||
model_row = await db.fetchrow("SELECT * FROM models WHERE id = $1", uuid.UUID(body["model_id"]))
|
||||
if not model_row:
|
||||
raise HTTPException(404, "model not found")
|
||||
|
||||
ds_row = await db.fetchrow("SELECT id FROM datasets WHERE id = $1", uuid.UUID(body["dataset_id"]))
|
||||
if not ds_row:
|
||||
raise HTTPException(404, "dataset not found")
|
||||
|
||||
try:
|
||||
training_row = await db.fetchrow(
|
||||
"""
|
||||
INSERT INTO trainings (model_id, version, patch, dataset_id, started_by, status)
|
||||
VALUES ($1,$2,$3,$4,$5,'queued')
|
||||
RETURNING *
|
||||
""",
|
||||
uuid.UUID(body["model_id"]),
|
||||
body["version"],
|
||||
body["patch"],
|
||||
uuid.UUID(body["dataset_id"]),
|
||||
user.get("username") or "unknown",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(409, f"training already exists or invalid: {e}")
|
||||
|
||||
training_id = str(training_row["id"])
|
||||
|
||||
# crea job lato api-service (cross-service registry)
|
||||
try:
|
||||
await api_client.create_job(
|
||||
"train",
|
||||
created_by=user.get("username") or "unknown",
|
||||
payload={
|
||||
"training_id": training_id,
|
||||
"model_id": body["model_id"],
|
||||
"version": body["version"],
|
||||
"patch": body["patch"],
|
||||
"dataset_id": body["dataset_id"],
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
# non-fatale: il worker locale può comunque procedere; logghiamo e continuiamo
|
||||
import logging
|
||||
logging.warning("create_job failed: %s", e)
|
||||
|
||||
# enqueue in Redis (il worker locale lo raccoglie)
|
||||
await redis_client.client().lpush("ml:queue:train", training_id)
|
||||
await redis_client.client().hset(
|
||||
f"ml:train:{training_id}",
|
||||
mapping={"status": "queued", "progress": "0", "message": "queued"},
|
||||
)
|
||||
await redis_client.client().expire(f"ml:train:{training_id}", 48 * 3600)
|
||||
|
||||
return _row(training_row)
|
||||
|
||||
|
||||
@router.get("/{training_id}")
|
||||
async def get_training(training_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", uuid.UUID(training_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return _row(row)
|
||||
|
||||
|
||||
@router.get("/{training_id}/artifacts")
|
||||
async def list_artifacts(training_id: str, user=Depends(require_auth)):
|
||||
row = await db.fetchrow(
|
||||
"SELECT artifacts_prefix FROM trainings WHERE id = $1", uuid.UUID(training_id)
|
||||
)
|
||||
if not row or not row["artifacts_prefix"]:
|
||||
raise HTTPException(404, "no artifacts")
|
||||
objs = minio_client.list_prefix(row["artifacts_prefix"] + "/")
|
||||
for o in objs:
|
||||
o["url"] = minio_client.presigned_get(o["name"], 3600)
|
||||
return objs
|
||||
64
ml/routers/trainings_stream.py
Normal file
64
ml/routers/trainings_stream.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""SSE endpoint per live progress del training.
|
||||
|
||||
GET /api/trainings/{id}/events
|
||||
Streamma eventi dal Redis stream `ml:train:{id}:events` via Server-Sent Events.
|
||||
Termina quando lo stato del training è terminale (succeeded/failed/cancelled).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from core import db, redis_client
|
||||
from core.auth import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/trainings", tags=["trainings-sse"])
|
||||
|
||||
_TERMINAL = {"succeeded", "failed", "cancelled"}
|
||||
|
||||
|
||||
@router.get("/{training_id}/events")
|
||||
async def training_events(training_id: str, user=Depends(require_auth)):
|
||||
# verifica esistenza
|
||||
row = await db.fetchrow("SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id))
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
|
||||
stream_key = f"ml:train:{training_id}:events"
|
||||
status_key = f"ml:train:{training_id}"
|
||||
|
||||
async def gen():
|
||||
last_id = "0-0"
|
||||
r = redis_client.client()
|
||||
while True:
|
||||
try:
|
||||
# XREAD block 5s per non tenere la connessione idle troppo a lungo
|
||||
resp = await r.xread({stream_key: last_id}, count=50, block=5000)
|
||||
except Exception as e:
|
||||
yield {"event": "error", "data": json.dumps({"error": str(e)})}
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
if resp:
|
||||
for _stream, entries in resp:
|
||||
for entry_id, fields in entries:
|
||||
last_id = entry_id
|
||||
yield {"event": "message", "id": entry_id, "data": json.dumps(fields)}
|
||||
|
||||
# controlla stato terminale
|
||||
state = await r.hget(status_key, "status")
|
||||
if not state:
|
||||
# fallback su db se redis scaduto
|
||||
db_row = await db.fetchrow(
|
||||
"SELECT status FROM trainings WHERE id = $1", uuid.UUID(training_id)
|
||||
)
|
||||
state = db_row["status"] if db_row else "unknown"
|
||||
if state in _TERMINAL:
|
||||
yield {"event": "end", "data": json.dumps({"status": state})}
|
||||
return
|
||||
|
||||
return EventSourceResponse(gen())
|
||||
Reference in New Issue
Block a user