include-metrics-from-training #6
19
README.md
19
README.md
@@ -164,12 +164,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
|
||||
```
|
||||
qc-cli train start Submit a SageMaker training job
|
||||
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||
qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking
|
||||
qc-cli train list List recent training jobs
|
||||
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
```
|
||||
|
||||
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||
|
||||
`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
|
||||
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||
|
||||
### `ai-hub`
|
||||
@@ -216,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
|
||||
Current behavior:
|
||||
|
||||
1. `qc-cli train start` submits a SageMaker training job.
|
||||
2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state.
|
||||
2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default.
|
||||
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||
- `qc_cli.stage=experiment`
|
||||
- `qc_cli.artifact_kind=trained_source`
|
||||
@@ -224,6 +227,20 @@ Current behavior:
|
||||
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||
|
||||
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
|
||||
|
||||
```json
|
||||
{
|
||||
"schema_version": 1,
|
||||
"steps": [
|
||||
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
|
||||
],
|
||||
"summary": {"summary.best_epoch": 0}
|
||||
}
|
||||
```
|
||||
|
||||
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration.
|
||||
|
||||
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
||||
|
||||
Example future metadata:
|
||||
|
||||
@@ -153,6 +153,14 @@ Or pass the job name explicitly:
|
||||
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
||||
```
|
||||
|
||||
To wait for completion and automatically import metrics and register the model, run:
|
||||
|
||||
```bash
|
||||
qc-cli train wait
|
||||
```
|
||||
|
||||
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
||||
|
||||
## SageMaker Outputs
|
||||
|
||||
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
|
||||
@@ -163,10 +171,13 @@ This example writes:
|
||||
best.pt
|
||||
model.onnx
|
||||
metrics.json
|
||||
training_metrics.json
|
||||
```
|
||||
|
||||
The archive is stored under the configured `s3.model_prefix`.
|
||||
|
||||
During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality.
|
||||
|
||||
## 6. Configure Qualcomm AI Hub
|
||||
|
||||
Authenticate with Qualcomm AI Hub:
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
|
||||
import yaml
|
||||
from sanitize_onnx import sanitize_onnx
|
||||
from training_metrics import write_training_metrics
|
||||
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
@@ -101,6 +102,7 @@ def main() -> None:
|
||||
if not trained_weights.exists():
|
||||
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
|
||||
|
||||
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
|
||||
copy_if_exists(trained_weights, model_dir / "best.pt")
|
||||
trained_model = YOLO(str(trained_weights))
|
||||
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
||||
|
||||
82
examples/meter-detection/source/training_metrics.py
Normal file
82
examples/meter-detection/source/training_metrics.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
METRIC_NAMES = {
|
||||
"metrics/precision(B)": "val.precision",
|
||||
"metrics/recall(B)": "val.recall",
|
||||
"metrics/mAP50(B)": "val.map50",
|
||||
"metrics/mAP50-95(B)": "val.map50_95",
|
||||
"train/box_loss": "train.box_loss",
|
||||
"train/cls_loss": "train.cls_loss",
|
||||
"train/dfl_loss": "train.dfl_loss",
|
||||
"val/box_loss": "val.box_loss",
|
||||
"val/cls_loss": "val.cls_loss",
|
||||
"val/dfl_loss": "val.dfl_loss",
|
||||
"time": "train.elapsed_seconds",
|
||||
}
|
||||
|
||||
|
||||
def write_training_metrics(results_csv: Path, destination: Path) -> None:
|
||||
steps = _read_metric_steps(results_csv)
|
||||
summary = _build_summary(steps)
|
||||
payload = {
|
||||
"schema_version": 1,
|
||||
"steps": steps,
|
||||
"summary": summary,
|
||||
}
|
||||
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
print(f"Saved {destination}")
|
||||
|
||||
|
||||
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
|
||||
if not results_csv.is_file():
|
||||
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
with results_csv.open(newline="", encoding="utf-8") as csv_file:
|
||||
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
|
||||
row = {str(key).strip(): value for key, value in raw_row.items()}
|
||||
raw_epoch = row.pop("epoch", row_index)
|
||||
step = int(float(raw_epoch))
|
||||
metrics: dict[str, float] = {}
|
||||
for source_name, raw_value in row.items():
|
||||
if raw_value is None or not raw_value.strip():
|
||||
continue
|
||||
try:
|
||||
value = float(raw_value)
|
||||
except ValueError:
|
||||
continue
|
||||
if math.isfinite(value):
|
||||
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
|
||||
steps.append({"step": step, "metrics": metrics})
|
||||
return steps
|
||||
|
||||
|
||||
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
|
||||
if not steps:
|
||||
return {}
|
||||
|
||||
summary: dict[str, float] = {}
|
||||
final_step = steps[-1]
|
||||
summary["summary.final_epoch"] = float(final_step["step"])
|
||||
for name, value in final_step["metrics"].items():
|
||||
summary[f"summary.final.{name}"] = value
|
||||
|
||||
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
|
||||
if scored_steps:
|
||||
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
|
||||
summary["summary.best_epoch"] = float(best_step["step"])
|
||||
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
|
||||
if "val.map50" in best_step["metrics"]:
|
||||
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
|
||||
return summary
|
||||
|
||||
|
||||
def _normalize_metric_name(name: str) -> str:
|
||||
normalized = name.replace("/", ".")
|
||||
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
|
||||
return normalized.strip("._") or "unnamed"
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,6 +22,8 @@ _STATUS_COLOR = {
|
||||
"Stopping": "yellow",
|
||||
"Stopped": "dim",
|
||||
}
|
||||
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
|
||||
DEFAULT_POLL_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _tracker(cfg):
|
||||
@@ -48,6 +51,57 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
||||
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||
|
||||
|
||||
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
|
||||
def _finalize_terminal_job(
|
||||
*,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
status: sm_ops.TrainingJobStatus,
|
||||
command: str,
|
||||
) -> None:
|
||||
if status.status not in _TERMINAL_STATUSES:
|
||||
return
|
||||
|
||||
st = state_ops.store(config_path)
|
||||
job_state = st.get_training_job(status.name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
if not run_id or job_state.get("mlflow_finalized_status"):
|
||||
return
|
||||
|
||||
tracker = _tracker(cfg)
|
||||
result = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
command=command,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if result.registered_model_version:
|
||||
updates["registered_model_version"] = result.registered_model_version
|
||||
st.update_training_job(status.name, **updates)
|
||||
|
||||
for warning in result.warnings:
|
||||
CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")
|
||||
if result.registered_model_version:
|
||||
st.set_latest_experiment_model_version(result.registered_model_version)
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
@@ -123,37 +177,65 @@ def status(
|
||||
raise typer.Exit(1)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
_print_training_status(status)
|
||||
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
|
||||
|
||||
job_state = st.get_training_job(job_name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
already_registered = job_state.get("registered_model_version")
|
||||
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||
tracker = _tracker(cfg)
|
||||
version = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if version:
|
||||
updates["registered_model_version"] = version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if version:
|
||||
st.set_latest_experiment_model_version(version)
|
||||
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
|
||||
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def wait(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between SageMaker status checks",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Wait for a training job and finalize its MLflow run."""
|
||||
cfg = load_cfg(config)
|
||||
st = state_ops.store(config)
|
||||
if not job_name:
|
||||
job_name = st.get_last_training_job()
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
previous_status: str | None = None
|
||||
try:
|
||||
while True:
|
||||
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if training_status.status != previous_status:
|
||||
color = _STATUS_COLOR.get(training_status.status, "white")
|
||||
CONSOLE.print(
|
||||
f"Job [cyan]{training_status.name}[/cyan]: "
|
||||
f"[{color}]{training_status.status}[/{color}]"
|
||||
)
|
||||
previous_status = training_status.status
|
||||
if training_status.status in _TERMINAL_STATUSES:
|
||||
_print_training_status(training_status)
|
||||
_finalize_terminal_job(
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
status=training_status,
|
||||
command="train wait",
|
||||
)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
return
|
||||
time.sleep(poll_interval)
|
||||
except KeyboardInterrupt:
|
||||
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
||||
raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
|
||||
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
|
||||
|
||||
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
|
||||
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]
|
||||
|
||||
93
src/tracking/metrics.py
Normal file
93
src/tracking/metrics.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
import math
|
||||
import tarfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
METRICS_ARTIFACT_NAME = "training_metrics.json"
|
||||
METRICS_SCHEMA_VERSION = 1
|
||||
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricStep:
|
||||
step: int
|
||||
metrics: dict[str, float]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainingMetrics:
|
||||
steps: list[MetricStep]
|
||||
summary: dict[str, float]
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
def parse_training_metrics(data: bytes) -> TrainingMetrics:
|
||||
try:
|
||||
value = json.loads(data)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
|
||||
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
|
||||
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
|
||||
|
||||
raw_steps = value.get("steps")
|
||||
if not isinstance(raw_steps, list):
|
||||
raise ValueError("training metrics 'steps' must be a list")
|
||||
|
||||
steps: list[MetricStep] = []
|
||||
previous_step: int | None = None
|
||||
for index, raw_step in enumerate(raw_steps):
|
||||
if not isinstance(raw_step, dict):
|
||||
raise ValueError(f"training metrics step {index} must be an object")
|
||||
step = raw_step.get("step")
|
||||
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
|
||||
raise ValueError(f"training metrics step {index} has an invalid 'step'")
|
||||
if previous_step is not None and step <= previous_step:
|
||||
raise ValueError("training metrics steps must be unique and strictly increasing")
|
||||
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
|
||||
steps.append(MetricStep(step=step, metrics=metrics))
|
||||
previous_step = step
|
||||
|
||||
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
|
||||
return TrainingMetrics(steps=steps, summary=summary, raw=value)
|
||||
|
||||
|
||||
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
|
||||
with tarfile.open(archive_path, mode="r:*") as archive:
|
||||
matches = [
|
||||
member
|
||||
for member in archive.getmembers()
|
||||
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
if len(matches) > 1:
|
||||
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
|
||||
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
|
||||
raise ValueError(
|
||||
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
|
||||
)
|
||||
extracted = archive.extractfile(matches[0])
|
||||
if extracted is None:
|
||||
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
|
||||
return extracted.read()
|
||||
|
||||
|
||||
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{context} 'metrics' must be an object")
|
||||
|
||||
metrics: dict[str, float] = {}
|
||||
for raw_name, raw_value in value.items():
|
||||
if not isinstance(raw_name, str) or not raw_name:
|
||||
raise ValueError(f"{context} contains an invalid metric name")
|
||||
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
|
||||
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
|
||||
metric_value = float(raw_value)
|
||||
if not math.isfinite(metric_value):
|
||||
raise ValueError(f"{context} metric '{raw_name}' must be finite")
|
||||
metrics[raw_name] = metric_value
|
||||
return metrics
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
@@ -6,13 +7,29 @@ import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
from src.aws import s3
|
||||
from src.config import Config, MlflowMode
|
||||
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinalizeResult:
|
||||
registered_model_version: str | None = None
|
||||
warnings: tuple[str, ...] = ()
|
||||
|
||||
|
||||
class Tracker(Protocol):
|
||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
|
||||
|
||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ...
|
||||
def finalize_training_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -20,8 +37,16 @@ class NoopTracker:
|
||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||
return None
|
||||
|
||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||
return None
|
||||
def finalize_training_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult:
|
||||
return FinalizeResult()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -88,10 +113,19 @@ class MlflowTracker:
|
||||
mlflow.end_run()
|
||||
return run_id
|
||||
|
||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||
def finalize_training_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult:
|
||||
if not run_id:
|
||||
return None
|
||||
return FinalizeResult()
|
||||
|
||||
warnings: list[str] = []
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
self._log_params(
|
||||
{
|
||||
@@ -103,14 +137,22 @@ class MlflowTracker:
|
||||
}
|
||||
)
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
mlflow.set_tag("qc_cli.command", "train status")
|
||||
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
|
||||
warnings.extend(
|
||||
self._log_training_metrics(
|
||||
training_job_status.model_artifacts,
|
||||
region=region,
|
||||
profile=profile,
|
||||
)
|
||||
)
|
||||
mlflow.set_tag("qc_cli.command", command)
|
||||
|
||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||
return None
|
||||
return FinalizeResult(warnings=tuple(warnings))
|
||||
|
||||
if not self.register_trained_models:
|
||||
return None
|
||||
return FinalizeResult(warnings=tuple(warnings))
|
||||
|
||||
client = MlflowClient()
|
||||
self._ensure_registered_model(client, self.registered_model_name)
|
||||
@@ -129,7 +171,7 @@ class MlflowTracker:
|
||||
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
||||
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||
return version_number
|
||||
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings))
|
||||
|
||||
def _log_params(self, params: dict[str, Any]) -> None:
|
||||
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||
@@ -146,6 +188,29 @@ class MlflowTracker:
|
||||
if metrics:
|
||||
mlflow.log_metrics(metrics)
|
||||
|
||||
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]:
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
||||
archive_path = s3.download_file(
|
||||
region,
|
||||
profile,
|
||||
model_artifacts,
|
||||
os.path.join(temp_dir, "model.tar.gz"),
|
||||
)
|
||||
metrics_data = read_training_metrics_from_tar(archive_path)
|
||||
if metrics_data is None:
|
||||
return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."]
|
||||
metrics = parse_training_metrics(metrics_data)
|
||||
for metric_step in metrics.steps:
|
||||
if metric_step.metrics:
|
||||
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
||||
if metrics.summary:
|
||||
mlflow.log_metrics(metrics.summary)
|
||||
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
||||
except Exception as exc:
|
||||
return [f"Could not import training metrics: {exc}"]
|
||||
return []
|
||||
|
||||
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||
try:
|
||||
client.get_registered_model(name)
|
||||
|
||||
Reference in New Issue
Block a user