"""API /api/trainings — enqueue, list, get, artifacts."""
from __future__ import annotations

import json
import logging
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query

from core import db, minio_client, redis_client, api_client
from core.auth import require_auth

router = APIRouter(prefix="/api/trainings", tags=["trainings"])

logger = logging.getLogger(__name__)

# Redis list that the local worker consumes training ids from.
_TRAIN_QUEUE_KEY = "ml:queue:train"
# Lifetime of the per-training status hash in Redis (48 hours).
_TRAIN_STATUS_TTL_SECONDS = 48 * 3600


def _row(r) -> Optional[dict]:
    """Convert a DB record into a JSON-friendly dict.

    Returns None when *r* is None. The ``queued_at``, ``started_at`` and
    ``finished_at`` values are rendered with ``isoformat()`` when present and
    when they expose that method; every other value is returned unchanged.
    """
    if r is None:
        return None
    d = dict(r)
    for k in ("queued_at", "started_at", "finished_at"):
        if d.get(k) is not None and hasattr(d[k], "isoformat"):
            d[k] = d[k].isoformat()
    return d


def _parse_uuid(value: object, field: str) -> uuid.UUID:
    """Parse client-supplied *value* as a UUID.

    Raises:
        HTTPException: 400 when *value* is not a valid UUID. Without this,
            ``uuid.UUID`` raises ValueError, which surfaces as an HTTP 500.
    """
    try:
        # str() first so non-string JSON values (ints, etc.) also end up as
        # ValueError rather than TypeError/AttributeError.
        return uuid.UUID(str(value))
    except ValueError as e:
        raise HTTPException(400, f"invalid {field}") from e


@router.get("")
async def list_trainings(
    model_id: Optional[str] = Query(None),
    status: Optional[str] = Query(None),
    limit: int = Query(100, ge=0, le=500),
    user=Depends(require_auth),
):
    """List trainings, newest ``queued_at`` first.

    Args:
        model_id: optional UUID string; restricts to trainings of that model.
        status: optional exact-match filter on the ``status`` column.
        limit: maximum rows returned (0..500).
        user: authenticated user (dependency; only enforces authentication).

    Returns:
        ``{"count": <n>, "trainings": [<row dict>, ...]}``.

    Raises:
        HTTPException: 400 if ``model_id`` is not a valid UUID.
    """
    where: list[str] = []
    args: list = []
    # Each placeholder is numbered from the current length of ``args`` so the
    # positional parameters ($1, $2, ...) always line up with their values.
    # Only fixed column names are interpolated into the SQL text.
    if model_id:
        args.append(_parse_uuid(model_id, "model_id"))
        where.append(f"model_id = ${len(args)}")
    if status:
        args.append(status)
        where.append(f"status = ${len(args)}")
    sql = "SELECT * FROM trainings"
    if where:
        sql += " WHERE " + " AND ".join(where)
    args.append(limit)
    sql += f" ORDER BY queued_at DESC LIMIT ${len(args)}"
    rows = await db.fetch(sql, *args)
    return {"count": len(rows), "trainings": [_row(r) for r in rows]}


@router.post("", status_code=202)
async def enqueue_training(body: dict, user=Depends(require_auth)):
    """Create a queued training record and hand it to the local worker.

    Body fields ``model_id``, ``version``, ``patch`` and ``dataset_id`` are
    required and must be truthy. The steps are:

    1. Insert a ``trainings`` row with status 'queued'.
    2. Best-effort registration of a "train" job with the api-service.
    3. Write the Redis status hash and push the id on the worker queue.

    Returns:
        The inserted training row (see ``_row``).

    Raises:
        HTTPException: 400 on a missing field or malformed UUID; 404 if the
            model or dataset does not exist; 409 if the insert fails.
    """
    for k in ("model_id", "version", "patch", "dataset_id"):
        if not body.get(k):
            raise HTTPException(400, f"missing field: {k}")
    model_id = _parse_uuid(body["model_id"], "model_id")
    dataset_id = _parse_uuid(body["dataset_id"], "dataset_id")
    username = user.get("username") or "unknown"

    # Existence checks only; no columns beyond the id are needed.
    model_row = await db.fetchrow("SELECT id FROM models WHERE id = $1", model_id)
    if not model_row:
        raise HTTPException(404, "model not found")
    ds_row = await db.fetchrow("SELECT id FROM datasets WHERE id = $1", dataset_id)
    if not ds_row:
        raise HTTPException(404, "dataset not found")

    try:
        training_row = await db.fetchrow(
            """
            INSERT INTO trainings (model_id, version, patch, dataset_id, started_by, status)
            VALUES ($1,$2,$3,$4,$5,'queued')
            RETURNING *
            """,
            model_id,
            body["version"],
            body["patch"],
            dataset_id,
            username,
        )
    except Exception as e:
        # NOTE(review): broad catch kept from the original. It presumably
        # targets a unique-constraint violation; narrow it to the DB driver's
        # exception type if possible.
        raise HTTPException(409, f"training already exists or invalid: {e}") from e
    training_id = str(training_row["id"])

    # Register the job with the api-service (cross-service registry).
    try:
        await api_client.create_job(
            "train",
            created_by=username,
            payload={
                "training_id": training_id,
                "model_id": body["model_id"],
                "version": body["version"],
                "patch": body["patch"],
                "dataset_id": body["dataset_id"],
            },
        )
    except Exception as e:
        # Non-fatal: the local worker can still proceed, so log and continue.
        logger.warning("create_job failed: %s", e)

    # Enqueue in Redis for the local worker. The status hash is written BEFORE
    # the id is pushed. With the opposite order, a fast worker could pop the id
    # and (presumably) update this hash. The late hset would then reset the
    # status back to "queued".
    client = redis_client.client()
    status_key = f"ml:train:{training_id}"
    await client.hset(
        status_key,
        mapping={"status": "queued", "progress": "0", "message": "queued"},
    )
    await client.expire(status_key, _TRAIN_STATUS_TTL_SECONDS)
    await client.lpush(_TRAIN_QUEUE_KEY, training_id)
    return _row(training_row)


@router.get("/{training_id}")
async def get_training(training_id: str, user=Depends(require_auth)):
    """Return a single training row.

    Raises:
        HTTPException: 400 if ``training_id`` is not a valid UUID; 404 if no
            such training exists.
    """
    row = await db.fetchrow(
        "SELECT * FROM trainings WHERE id = $1", _parse_uuid(training_id, "training_id")
    )
    if not row:
        raise HTTPException(404, "not found")
    return _row(row)


@router.get("/{training_id}/artifacts")
async def list_artifacts(training_id: str, user=Depends(require_auth)):
    """List the objects under a training's artifacts prefix.

    Each entry returned by ``minio_client.list_prefix`` gets a ``url`` key
    holding a presigned GET URL valid for 3600 seconds.

    Raises:
        HTTPException: 400 if ``training_id`` is not a valid UUID; 404 if the
            training is missing or has no ``artifacts_prefix``.
    """
    row = await db.fetchrow(
        "SELECT artifacts_prefix FROM trainings WHERE id = $1",
        _parse_uuid(training_id, "training_id"),
    )
    if not row or not row["artifacts_prefix"]:
        raise HTTPException(404, "no artifacts")
    # NOTE(review): these minio_client calls are synchronous (not awaited). If
    # they perform network I/O they block the event loop; consider running
    # them in a threadpool.
    objs = minio_client.list_prefix(row["artifacts_prefix"] + "/")
    for o in objs:
        o["url"] = minio_client.presigned_get(o["name"], 3600)
    return objs