This commit is contained in:
2026-06-12 11:42:26 -04:00
parent 522ddc74e2
commit 2d4d377051
8 changed files with 390 additions and 38 deletions

View File

@@ -164,12 +164,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
```
qc-cli train start Submit a SageMaker training job
qc-cli train status [job-name] Show job status; defaults to the last submitted job
qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking
qc-cli train list List recent training jobs
qc-cli train list --limit 3 Show a custom number of recent jobs
```
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
The expected output artifact is SageMaker's `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
### `ai-hub`
@@ -216,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
Current behavior:
1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state.
2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default.
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
- `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source`
@@ -224,6 +227,20 @@ Current behavior:
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
```json
{
"schema_version": 1,
"steps": [
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
],
"summary": {"summary.best_epoch": 0}
}
```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
Example future metadata:

View File

@@ -153,6 +153,14 @@ Or pass the job name explicitly:
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
```
To wait for completion and automatically import metrics and register the model, run:
```bash
qc-cli train wait
```
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
## SageMaker Outputs
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
@@ -163,10 +171,13 @@ This example writes:
best.pt
model.onnx
metrics.json
training_metrics.json
```
The archive is stored under the configured `s3.model_prefix`.
During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality.
## 6. Configure Qualcomm AI Hub
Authenticate with Qualcomm AI Hub:

View File

@@ -12,6 +12,7 @@ from typing import Any
import yaml
from sanitize_onnx import sanitize_onnx
from training_metrics import write_training_metrics
from ultralytics import YOLO # type: ignore[reportMissingImports]
@@ -101,6 +102,7 @@ def main() -> None:
if not trained_weights.exists():
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))

View File

@@ -0,0 +1,82 @@
import csv
import json
import math
import re
from pathlib import Path
from typing import Any
# Maps column names read from the results CSV to the metric names written to
# training_metrics.json. Columns that are not listed here are renamed by
# _normalize_metric_name instead.
METRIC_NAMES = {
    "metrics/precision(B)": "val.precision",
    "metrics/recall(B)": "val.recall",
    "metrics/mAP50(B)": "val.map50",
    "metrics/mAP50-95(B)": "val.map50_95",
    "train/box_loss": "train.box_loss",
    "train/cls_loss": "train.cls_loss",
    "train/dfl_loss": "train.dfl_loss",
    "val/box_loss": "val.box_loss",
    "val/cls_loss": "val.cls_loss",
    "val/dfl_loss": "val.dfl_loss",
    "time": "train.elapsed_seconds",
}
def write_training_metrics(results_csv: Path, destination: Path) -> None:
    """Convert a results CSV into a ``training_metrics.json`` document.

    The document holds ``schema_version`` 1, the per-row ``steps`` read from
    ``results_csv`` and a ``summary`` derived from those steps. It is written
    to ``destination`` as UTF-8 JSON and the destination path is printed.

    Raises:
        FileNotFoundError: Propagated from ``_read_metric_steps`` when
            ``results_csv`` does not exist.
    """
    history = _read_metric_steps(results_csv)
    document = {
        "schema_version": 1,
        "steps": history,
        "summary": _build_summary(history),
    }
    serialized = json.dumps(document, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    print(f"Saved {destination}")
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
    """Read per-row metrics from an Ultralytics ``results.csv`` history file.

    Every CSV row becomes ``{"step": <int>, "metrics": {<name>: <float>}}``.
    Header names are stripped of surrounding whitespace. The ``epoch`` column
    supplies the step; when it is missing, blank, non-numeric or non-finite,
    the zero-based row index is used instead so that one bad cell cannot abort
    the export. Cells that are empty, non-numeric or non-finite are skipped.
    Column names are translated through ``METRIC_NAMES`` and otherwise through
    ``_normalize_metric_name``.

    Args:
        results_csv: Path to the CSV file to read.

    Returns:
        One dictionary per CSV row, in file order.

    Raises:
        FileNotFoundError: If ``results_csv`` is not an existing file.
    """
    if not results_csv.is_file():
        raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")

    steps: list[dict[str, Any]] = []
    with results_csv.open(newline="", encoding="utf-8") as csv_file:
        for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
            row = {str(key).strip(): value for key, value in raw_row.items()}

            raw_epoch = row.pop("epoch", row_index)
            try:
                step = int(float(raw_epoch))
            except (TypeError, ValueError, OverflowError):
                # Blank, textual, NaN or infinite epoch cell: fall back to the
                # same default that is used when the column is absent.
                step = row_index

            metrics: dict[str, float] = {}
            for source_name, raw_value in row.items():
                # DictReader fills short rows with None and collects overflow
                # cells into a list, so only real strings can hold a number.
                if not isinstance(raw_value, str) or not raw_value.strip():
                    continue
                try:
                    value = float(raw_value)
                except ValueError:
                    continue
                if math.isfinite(value):
                    metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
            steps.append({"step": step, "metrics": metrics})
    return steps
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
    """Derive summary values from the ordered list of metric steps.

    The summary records the last step (``summary.final_epoch``) and copies
    each of its metrics under ``summary.final.<name>``. When at least one step
    carries ``val.map50_95``, the step with the highest value is reported as
    ``summary.best_epoch`` together with its ``val.map50_95`` and, if present,
    its ``val.map50``. An empty ``steps`` list yields an empty summary.
    """
    if not steps:
        return {}

    last = steps[-1]
    summary: dict[str, float] = {"summary.final_epoch": float(last["step"])}
    summary.update(
        (f"summary.final.{name}", value) for name, value in last["metrics"].items()
    )

    candidates = [entry for entry in steps if "val.map50_95" in entry["metrics"]]
    if not candidates:
        return summary

    # max() keeps the earliest step when several share the top score.
    best = max(candidates, key=lambda entry: entry["metrics"]["val.map50_95"])
    best_metrics = best["metrics"]
    summary["summary.best_epoch"] = float(best["step"])
    summary["summary.best_val.map50_95"] = best_metrics["val.map50_95"]
    if "val.map50" in best_metrics:
        summary["summary.best_val.map50"] = best_metrics["val.map50"]
    return summary
def _normalize_metric_name(name: str) -> str:
    """Turn an arbitrary column name into a dotted metric name.

    Slashes become dots, every run of characters outside ``A-Z a-z 0-9 _ . -``
    collapses to a single underscore, and leading/trailing dots and
    underscores are removed. A name that ends up empty becomes ``"unnamed"``.
    """
    dotted = name.replace("/", ".")
    collapsed = re.sub(r"[^A-Za-z0-9_.-]+", "_", dotted)
    trimmed = collapsed.strip("._")
    return trimmed if trimmed else "unnamed"

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from pathlib import Path
@@ -21,6 +22,8 @@ _STATUS_COLOR = {
"Stopping": "yellow",
"Stopped": "dim",
}
# Job statuses after which `train wait` stops polling and MLflow finalization is attempted.
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
# Default number of seconds between status checks for `train wait --poll-interval`.
DEFAULT_POLL_INTERVAL_SECONDS = 30
def _tracker(cfg):
@@ -48,6 +51,57 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
    """Print a training job's name and status, plus any available details.

    The creation time, artifact location and failure reason are each printed
    only when the corresponding attribute is truthy.
    """
    status_color = _STATUS_COLOR.get(status.status, "white")
    lines = [
        f"Job: [cyan]{status.name}[/cyan]",
        f"Status: [{status_color}]{status.status}[/{status_color}]",
    ]
    if status.created:
        lines.append(f"Created: {status.created}")
    if status.model_artifacts:
        lines.append(f"Artifacts: {status.model_artifacts}")
    if status.failure_reason:
        lines.append(f"[red]Failure: {status.failure_reason}[/red]")

    for line in lines:
        CONSOLE.print(line)
def _finalize_terminal_job(
    *,
    config_path: str,
    cfg: Config,
    status: sm_ops.TrainingJobStatus,
    command: str,
) -> None:
    """Finalize the MLflow run of a job that reached a terminal status.

    Nothing happens when the status is not in ``_TERMINAL_STATUSES``, when the
    stored job state has no ``mlflow_run_id``, or when it already carries a
    ``mlflow_finalized_status``. Otherwise the tracker finalizes the run, the
    job state is updated, warnings are printed, and a registered model version
    (if any) is stored as the latest experiment version and printed.

    Args:
        config_path: Path of the config file used to locate the state store.
        cfg: Loaded configuration; supplies the AWS region and profile.
        status: Training job status to finalize.
        command: Command name passed through to the tracker.
    """
    if status.status not in _TERMINAL_STATUSES:
        return

    store = state_ops.store(config_path)
    job_record = store.get_training_job(status.name)
    mlflow_run_id = job_record.get("mlflow_run_id")
    if not (mlflow_run_id and not job_record.get("mlflow_finalized_status")):
        return

    outcome = _tracker(cfg).finalize_training_run(
        run_id=str(mlflow_run_id),
        training_job_status=status,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        command=command,
    )
    model_version = outcome.registered_model_version

    # Persist the finalized marker first so a later run does not repeat the work.
    state_changes = {"mlflow_finalized_status": status.status}
    if model_version:
        state_changes["registered_model_version"] = model_version
    store.update_training_job(status.name, **state_changes)

    for warning in outcome.warnings:
        CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")

    if not model_version:
        return
    store.set_latest_experiment_model_version(model_version)
    CONSOLE.print(
        f"MLflow model version: [cyan]{model_version}[/cyan] "
        "([cyan]experiment-latest[/cyan])"
    )
@app.command()
def start(config: str = CONFIG_OPT) -> None:
"""Submit a SageMaker training job."""
@@ -123,37 +177,65 @@ def status(
raise typer.Exit(1)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
color = _STATUS_COLOR.get(status.status, "white")
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
_print_training_status(status)
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
job_state = st.get_training_job(job_name)
run_id = job_state.get("mlflow_run_id")
already_registered = job_state.get("registered_model_version")
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
tracker = _tracker(cfg)
version = tracker.finalize_training_run(
run_id=str(run_id),
training_job_status=status,
)
updates = {"mlflow_finalized_status": status.status}
if version:
updates["registered_model_version"] = version
st.update_training_job(job_name, **updates)
if version:
st.set_latest_experiment_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    cfg = load_cfg(config)
    st = state_ops.store(config)

    # Fall back to the most recently submitted job when no name was given.
    job_name = job_name or st.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    last_seen: str | None = None
    try:
        while True:
            current = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)

            # Only report the status line when it differs from the previous poll.
            if current.status != last_seen:
                color = _STATUS_COLOR.get(current.status, "white")
                CONSOLE.print(
                    f"Job [cyan]{current.name}[/cyan]: "
                    f"[{color}]{current.status}[/{color}]"
                )
                last_seen = current.status

            if current.status not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            _print_training_status(current)
            _finalize_terminal_job(
                config_path=config,
                cfg=cfg,
                status=current,
                command="train wait",
            )
            job_state = st.get_training_job(job_name)
            if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        # Ctrl-C only ends the local polling loop.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),

View File

@@ -1,3 +1,3 @@
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]

93
src/tracking/metrics.py Normal file
View File

@@ -0,0 +1,93 @@
import json
import math
import tarfile
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any
# File name searched for inside the model archive (compared against each member's base name).
METRICS_ARTIFACT_NAME = "training_metrics.json"
# The only "schema_version" value that parse_training_metrics accepts.
METRICS_SCHEMA_VERSION = 1
# Largest archive member size (10 MiB) that will be read into memory.
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
@dataclass(frozen=True)
class MetricStep:
    """Metrics recorded at one step of a training run."""

    # Step index; parse_training_metrics only produces non-negative, strictly increasing values.
    step: int
    # Metric name mapped to its finite float value.
    metrics: dict[str, float]
@dataclass(frozen=True)
class TrainingMetrics:
    """Validated contents of a training metrics document."""

    # Per-step metrics in the order they appear in the document.
    steps: list[MetricStep]
    # Run-level metrics taken from the document's "summary" object (empty when absent).
    summary: dict[str, float]
    # The decoded JSON object exactly as parsed, before validation.
    raw: dict[str, Any]
def parse_training_metrics(data: bytes) -> TrainingMetrics:
    """Parse and validate the contents of a training metrics JSON document.

    The document must be a JSON object whose ``schema_version`` equals
    ``METRICS_SCHEMA_VERSION`` and whose ``steps`` is a list of objects. Each
    step needs a non-negative integer ``step`` that is strictly greater than
    the previous one, and a ``metrics`` object of finite numbers. ``summary``
    is optional and validated like a ``metrics`` object.

    Args:
        data: Raw JSON bytes.

    Returns:
        The validated steps and summary together with the decoded document.

    Raises:
        ValueError: If the data is not valid JSON or violates any rule above.
    """
    try:
        value = json.loads(data)
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
    if not isinstance(value, dict):
        raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")

    schema_version = value.get("schema_version")
    # bool is a subclass of int and True == 1, so JSON `true` must be rejected
    # explicitly, in the same way booleans are rejected for `step` below.
    if isinstance(schema_version, bool) or schema_version != METRICS_SCHEMA_VERSION:
        raise ValueError(f"Unsupported training metrics schema version: {schema_version!r}")

    raw_steps = value.get("steps")
    if not isinstance(raw_steps, list):
        raise ValueError("training metrics 'steps' must be a list")

    steps: list[MetricStep] = []
    previous_step: int | None = None
    for index, raw_step in enumerate(raw_steps):
        if not isinstance(raw_step, dict):
            raise ValueError(f"training metrics step {index} must be an object")
        step = raw_step.get("step")
        if isinstance(step, bool) or not isinstance(step, int) or step < 0:
            raise ValueError(f"training metrics step {index} has an invalid 'step'")
        if previous_step is not None and step <= previous_step:
            raise ValueError("training metrics steps must be unique and strictly increasing")
        metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
        steps.append(MetricStep(step=step, metrics=metrics))
        previous_step = step

    summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
    return TrainingMetrics(steps=steps, summary=summary, raw=value)
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
    """Return the bytes of the metrics file stored in a tar archive.

    Regular-file members are matched on their base name, so the file may sit
    in any directory of the archive.

    Args:
        archive_path: Path of the tar archive (any compression tarfile detects).

    Returns:
        The file contents, or ``None`` when the archive has no such member.

    Raises:
        ValueError: If more than one member matches, the member is larger than
            ``MAX_METRICS_ARTIFACT_BYTES``, or the member cannot be read.
    """
    with tarfile.open(archive_path, mode="r:*") as archive:
        candidates = []
        for member in archive.getmembers():
            if not member.isfile():
                continue
            if PurePosixPath(member.name).name != METRICS_ARTIFACT_NAME:
                continue
            candidates.append(member)

        if not candidates:
            return None
        if len(candidates) > 1:
            raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")

        metrics_member = candidates[0]
        # Check the declared size before reading so an oversized file is never loaded.
        if metrics_member.size > MAX_METRICS_ARTIFACT_BYTES:
            raise ValueError(
                f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
            )

        stream = archive.extractfile(metrics_member)
        if stream is None:
            raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
        return stream.read()
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
    """Validate a mapping of metric names to numbers and return it as floats.

    Args:
        value: Decoded JSON value expected to be an object.
        context: Prefix used in error messages to identify the offending part
            of the document.

    Returns:
        A new dictionary with every value converted to ``float``.

    Raises:
        ValueError: If ``value`` is not a dict, a name is not a non-empty
            string, or a value is a bool, non-numeric, or not representable as
            a finite float.
    """
    if not isinstance(value, dict):
        raise ValueError(f"{context} 'metrics' must be an object")

    metrics: dict[str, float] = {}
    for raw_name, raw_value in value.items():
        if not isinstance(raw_name, str) or not raw_name:
            raise ValueError(f"{context} contains an invalid metric name")
        # bool is an int subclass, so it has to be excluded before the numeric check.
        if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
            raise ValueError(f"{context} metric '{raw_name}' must be numeric")
        try:
            metric_value = float(raw_value)
        except OverflowError as exc:
            # JSON integers are unbounded; one too large for a float is reported
            # like any other non-finite value instead of leaking OverflowError.
            raise ValueError(f"{context} metric '{raw_name}' must be finite") from exc
        if not math.isfinite(metric_value):
            raise ValueError(f"{context} metric '{raw_name}' must be finite")
        metrics[raw_name] = metric_value
    return metrics

View File

@@ -1,4 +1,5 @@
import os
import tempfile
from dataclasses import dataclass
from typing import Any, Protocol
@@ -6,13 +7,29 @@ import mlflow
from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow
from src.aws import s3
from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
@dataclass(frozen=True)
class FinalizeResult:
    """Value returned by ``finalize_training_run``."""

    # Model version registered during finalization; None when nothing was registered.
    registered_model_version: str | None = None
    # Messages describing problems encountered while importing training metrics.
    warnings: tuple[str, ...] = ()
class Tracker(Protocol):
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ...
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult: ...
@dataclass(frozen=True)
@@ -20,8 +37,16 @@ class NoopTracker:
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
return None
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
return None
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
return FinalizeResult()
@dataclass(frozen=True)
@@ -88,10 +113,19 @@ class MlflowTracker:
mlflow.end_run()
return run_id
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
if not run_id:
return None
return FinalizeResult()
warnings: list[str] = []
with mlflow.start_run(run_id=run_id):
self._log_params(
{
@@ -103,14 +137,22 @@ class MlflowTracker:
}
)
self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status")
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
warnings.extend(
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
)
mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return None
return FinalizeResult(warnings=tuple(warnings))
if not self.register_trained_models:
return None
return FinalizeResult(warnings=tuple(warnings))
client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name)
@@ -129,7 +171,7 @@ class MlflowTracker:
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number)
return version_number
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings))
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -146,6 +188,29 @@ class MlflowTracker:
if metrics:
mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]:
    """Import metrics from the model archive into the active MLflow run.

    The archive at ``model_artifacts`` is downloaded into a temporary
    directory and searched for the metrics file. Each step that has metrics is
    logged at its own step index, the summary is logged without a step, and
    the decoded document is stored as a run artifact.

    Returns:
        An empty list on success, otherwise a single warning message. Any
        exception is reported as a warning instead of being raised, so a
        metrics problem does not interrupt finalization.
    """
    try:
        with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as workdir:
            local_target = os.path.join(workdir, "model.tar.gz")
            archive_path = s3.download_file(region, profile, model_artifacts, local_target)

            payload = read_training_metrics_from_tar(archive_path)
            if payload is None:
                return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."]

            parsed = parse_training_metrics(payload)
            for entry in parsed.steps:
                if not entry.metrics:
                    continue
                mlflow.log_metrics(entry.metrics, step=entry.step)
            if parsed.summary:
                mlflow.log_metrics(parsed.summary)
            mlflow.log_dict(parsed.raw, METRICS_ARTIFACT_NAME)
    except Exception as exc:
        return [f"Could not import training metrics: {exc}"]
    return []
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try:
client.get_registered_model(name)