feat: Add new API endpoints and HTML pages for ML model management
- Implemented HTML pages for datasets, models, training, testing, and results. - Created API endpoints for managing repositories, results, tests, and training sessions. - Added functionality for streaming training progress via Server-Sent Events (SSE). - Introduced a Dockerfile for the ML runner with necessary dependencies. - Developed an SDK for user code execution within the runner container. - Enhanced CSS styles for improved UI layout and navigation. - Established a layout template for consistent HTML structure across pages. - Added JavaScript for dynamic interactions on the models page. - Implemented WebSocket handling for real-time communication with kiosk devices and controllers. - Implemented model registration and management API at /api/models - Added Gitea proxy API for repository interactions at /api/repos - Created results API for listing and comparing training results at /api/results - Developed training management API for enqueueing and retrieving training jobs at /api/trainings - Introduced SSE endpoint for live training progress updates - Added HTML pages for models, datasets, and training management - Created a Dockerfile for the ML runner with necessary dependencies - Developed SDK for user code execution within the runner container - Enhanced CSS styles for improved UI/UX - Implemented WebSocket communication for real-time device and controller interactions in the kiosk system
This commit is contained in:
129
ml/routers/trainings.py
Normal file
129
ml/routers/trainings.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""API /api/trainings — enqueue, list, get, artifacts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from core import db, minio_client, redis_client, api_client
|
||||
from core.auth import require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/trainings", tags=["trainings"])
|
||||
|
||||
|
||||
def _row(r) -> Optional[dict]:
|
||||
if r is None:
|
||||
return None
|
||||
d = dict(r)
|
||||
for k in ("queued_at", "started_at", "finished_at"):
|
||||
if d.get(k) is not None and hasattr(d[k], "isoformat"):
|
||||
d[k] = d[k].isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@router.get("")
async def list_trainings(
    model_id: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    limit: int = Query(100, ge=0, le=500),
    user=Depends(require_auth),
):
    """List trainings, newest ``queued_at`` first, with optional filters.

    Args:
        model_id: When given, only trainings of this model (UUID string).
        status: When given, only trainings whose ``status`` equals this value.
        limit: Maximum number of rows to return (0..500, default 100).
        user: Result of the ``require_auth`` dependency; not used in the body.

    Returns:
        ``{"count": <number of rows>, "trainings": [<row dict>, ...]}``.

    Raises:
        HTTPException: 400 when ``model_id`` is not a valid UUID string.
    """
    conditions: list = []
    args: list = []

    if model_id:
        # A malformed id is a client error; without this guard the
        # ValueError from uuid.UUID would surface as an unhandled exception.
        try:
            model_uuid = uuid.UUID(model_id)
        except ValueError as exc:
            raise HTTPException(400, "invalid model_id") from exc
        args.append(model_uuid)
        conditions.append(f"model_id = ${len(args)}")

    if status:
        args.append(status)
        conditions.append(f"status = ${len(args)}")

    # Placeholders are numbered from the current length of ``args`` so every
    # user-supplied value is bound as a parameter, never interpolated.
    sql = "SELECT * FROM trainings"
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    args.append(limit)
    sql += f" ORDER BY queued_at DESC LIMIT ${len(args)}"

    rows = await db.fetch(sql, *args)
    return {"count": len(rows), "trainings": [_row(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("", status_code=202)
async def enqueue_training(body: dict, user=Depends(require_auth)):
    """Create a training record and queue it for the local worker.

    Steps, in order: validate the request body, check that the referenced
    model and dataset exist, insert a ``queued`` row into ``trainings``,
    register a ``train`` job through ``api_client`` (best effort), then
    publish the initial status hash and push the training id onto the
    ``ml:queue:train`` Redis list.

    Args:
        body: JSON object with ``model_id``, ``version``, ``patch`` and
            ``dataset_id``.
        user: Result of the ``require_auth`` dependency; its ``username``
            is recorded as the creator (``"unknown"`` when absent/empty).

    Returns:
        The inserted training row as a dict (timestamps ISO-formatted).

    Raises:
        HTTPException: 400 for a missing field or malformed UUID, 404 when
            the model or dataset does not exist, 409 when the insert fails.
    """
    for k in ("model_id", "version", "patch", "dataset_id"):
        value = body.get(k)
        # Only absent, null or empty-string values count as missing; a plain
        # truthiness test would also reject a legitimate numeric 0.
        if value is None or value == "":
            raise HTTPException(400, f"missing field: {k}")

    # Parse both ids once, up front, so malformed input is reported as a
    # client error instead of escaping as an unhandled exception.
    try:
        model_id = uuid.UUID(body["model_id"])
        dataset_id = uuid.UUID(body["dataset_id"])
    except (ValueError, TypeError, AttributeError) as exc:
        raise HTTPException(400, "invalid model_id or dataset_id") from exc

    model_row = await db.fetchrow("SELECT * FROM models WHERE id = $1", model_id)
    if not model_row:
        raise HTTPException(404, "model not found")

    ds_row = await db.fetchrow("SELECT id FROM datasets WHERE id = $1", dataset_id)
    if not ds_row:
        raise HTTPException(404, "dataset not found")

    started_by = user.get("username") or "unknown"

    try:
        training_row = await db.fetchrow(
            """
            INSERT INTO trainings (model_id, version, patch, dataset_id, started_by, status)
            VALUES ($1,$2,$3,$4,$5,'queued')
            RETURNING *
            """,
            model_id,
            body["version"],
            body["patch"],
            dataset_id,
            started_by,
        )
    except Exception as e:
        # Any insert failure is reported as a conflict; the cause is chained
        # so the original database error stays visible in server tracebacks.
        raise HTTPException(409, f"training already exists or invalid: {e}") from e

    training_id = str(training_row["id"])

    # Create the job on the api-service side (cross-service registry).
    try:
        await api_client.create_job(
            "train",
            created_by=started_by,
            payload={
                "training_id": training_id,
                "model_id": body["model_id"],
                "version": body["version"],
                "patch": body["patch"],
                "dataset_id": body["dataset_id"],
            },
        )
    except Exception as e:
        # Non-fatal: the local worker can still proceed; log and continue.
        import logging

        logging.warning("create_job failed: %s", e)

    # Enqueue in Redis (the local worker picks it up). The status hash is
    # written BEFORE the id is pushed so that whatever consumes
    # ``ml:queue:train`` (not visible here) cannot pop the id and update the
    # hash only to have it overwritten by this initial "queued" state.
    redis = redis_client.client()
    state_key = f"ml:train:{training_id}"
    await redis.hset(
        state_key,
        mapping={"status": "queued", "progress": "0", "message": "queued"},
    )
    await redis.expire(state_key, 48 * 3600)
    await redis.lpush("ml:queue:train", training_id)

    return _row(training_row)
|
||||
|
||||
|
||||
@router.get("/{training_id}")
async def get_training(training_id: str, user=Depends(require_auth)):
    """Return a single training row by id.

    Args:
        training_id: UUID string taken from the URL path.
        user: Result of the ``require_auth`` dependency; not used in the body.

    Returns:
        The training row as a dict (timestamps ISO-formatted).

    Raises:
        HTTPException: 400 when ``training_id`` is not a valid UUID string,
            404 when no training has that id.
    """
    # A malformed id is a client error; without this guard the ValueError
    # from uuid.UUID would surface as an unhandled exception.
    try:
        training_uuid = uuid.UUID(training_id)
    except ValueError as exc:
        raise HTTPException(400, "invalid training_id") from exc

    row = await db.fetchrow("SELECT * FROM trainings WHERE id = $1", training_uuid)
    if not row:
        raise HTTPException(404, "not found")
    return _row(row)
|
||||
|
||||
|
||||
@router.get("/{training_id}/artifacts")
async def list_artifacts(training_id: str, user=Depends(require_auth)):
    """List the stored artifacts of a training, each with a download URL.

    Args:
        training_id: UUID string taken from the URL path.
        user: Result of the ``require_auth`` dependency; not used in the body.

    Returns:
        The objects returned by ``minio_client.list_prefix`` for the
        training's ``artifacts_prefix``, each extended with a ``url`` key
        obtained from ``minio_client.presigned_get``.

    Raises:
        HTTPException: 400 when ``training_id`` is not a valid UUID string,
            404 when the training does not exist or has no artifacts prefix.
    """
    # A malformed id is a client error; without this guard the ValueError
    # from uuid.UUID would surface as an unhandled exception.
    try:
        training_uuid = uuid.UUID(training_id)
    except ValueError as exc:
        raise HTTPException(400, "invalid training_id") from exc

    row = await db.fetchrow(
        "SELECT artifacts_prefix FROM trainings WHERE id = $1", training_uuid
    )
    if not row or not row["artifacts_prefix"]:
        raise HTTPException(404, "no artifacts")

    objs = minio_client.list_prefix(row["artifacts_prefix"] + "/")
    for o in objs:
        # 3600 is presumably the link lifetime in seconds (one hour) —
        # confirm against minio_client.presigned_get.
        o["url"] = minio_client.presigned_get(o["name"], 3600)
    return objs
|
||||
Reference in New Issue
Block a user