include-metrics-from-training (#6)
Reviewed-on: #6
This commit was merged in pull request #6.
45
README.md
@@ -105,7 +105,11 @@ mlflow:
   tracking_server_name: your-tracking-server-name
 ```
 
-When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release.
+When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Metric upload through
+`train start --upload-metrics` or `mlflow upload-metrics` finalizes that run and registers completed model artifacts
+as experiment model versions using the `experiment-latest` MLflow alias. `train status` reads SageMaker status only.
+An experiment version is an immutable trained-source artifact; it records that training produced a model, not that
+the model is better than earlier versions or ready for release.
 
 To open the managed SageMaker MLflow UI, request a fresh presigned URL:
 
@@ -128,9 +132,14 @@ qc-cli init --force Overwrite an existing config file
 ### `mlflow`
 
 ```
-qc-cli mlflow open                        Open a presigned MLflow UI URL in a browser
+qc-cli mlflow open                        Open a presigned MLflow UI URL
+qc-cli mlflow upload-metrics [job-name]   Upload completed training metrics
 ```
 
+`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run,
+imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`.
+Use `--force` to upload the metrics again.
+
 ### `infra`
 
 ```
@@ -163,6 +172,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
 
 ```
 qc-cli train start                    Submit a SageMaker training job
+qc-cli train start --upload-metrics   Submit, wait, and upload metrics
 qc-cli train status [job-name]        Show job status; defaults to the last submitted job
 qc-cli train list                     List recent training jobs
 qc-cli train list --limit 3           Show a custom number of recent jobs
@@ -170,6 +180,8 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
 
 `train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
 
+`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
+
 The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
 
 ### `ai-hub`
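
The polling behaviour documented for `--upload-metrics` above (check the job status at a fixed interval until it reaches a terminal state) can be sketched in isolation. This is a minimal illustration rather than the CLI's implementation: `wait_for_terminal_status`, `get_status`, and the injectable `sleep` parameter are hypothetical names, and the terminal status set mirrors the `_TERMINAL_STATUSES` constant added elsewhere in this pull request.

```python
import time
from collections.abc import Callable

# Terminal SageMaker job statuses, as listed in this pull request's train command.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_terminal_status(
    get_status: Callable[[], str],
    poll_interval: int = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns a terminal status (hypothetical helper)."""
    if poll_interval < 1:
        raise ValueError("poll_interval must be a positive number of seconds")
    while True:
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        # Wait before asking again; the remote job keeps running regardless.
        sleep(poll_interval)
```

Passing `sleep` as a parameter keeps the loop testable without real waiting. A caller would supply a `get_status` function that queries SageMaker for the job's current status.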
@@ -216,13 +228,34 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
 Current behavior:
 
 1. `qc-cli train start` submits a SageMaker training job.
-2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state.
-3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
+2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
+3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
+4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
+5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
    - `qc_cli.stage=experiment`
    - `qc_cli.artifact_kind=trained_source`
    - `qc_cli.source=sagemaker`
-4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
-5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
+6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
+7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
 
+Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the
+explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow
+step and stores the JSON as a run artifact:
+
+```json
+{
+  "schema_version": 1,
+  "steps": [
+    {"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
+  ],
+  "summary": {"summary.best_epoch": 0}
+}
+```
+
+Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
+strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues
+model registration without per-epoch history. A malformed metrics artifact still fails the upload command without
+affecting the trained model or model registration.
+
 Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
 
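
For a training script that does not use the Ultralytics helper from this pull request, the file format and validation rules described in the README hunk above could be produced and checked roughly as follows. This is a sketch under assumptions: `validate_training_metrics` and `write_training_metrics_file` are hypothetical names, rejecting any `schema_version` other than 1 is an assumption, and the `summary` object is written as-is because the README does not state rules for it.

```python
import json
import math
from pathlib import Path
from typing import Any


def validate_training_metrics(payload: dict[str, Any]) -> None:
    """Raise ValueError if payload breaks the documented rules (hypothetical validator)."""
    # Assumption: only schema_version 1, the version shown in the README, is accepted.
    if payload.get("schema_version") != 1:
        raise ValueError("unsupported schema_version")
    previous_step = -1
    for entry in payload.get("steps", []):
        step = entry.get("step")
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(step, bool) or not isinstance(step, int) or step < 0:
            raise ValueError(f"step must be a non-negative integer: {step!r}")
        # Strictly increasing steps are automatically unique.
        if step <= previous_step:
            raise ValueError(f"steps must be strictly increasing: {step} after {previous_step}")
        previous_step = step
        for name, value in entry.get("metrics", {}).items():
            if not isinstance(name, str) or not name:
                raise ValueError("metric names must be non-empty strings")
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if not is_number or not math.isfinite(value):
                raise ValueError(f"metric {name!r} must be a finite number")


def write_training_metrics_file(model_dir: Path, payload: dict[str, Any]) -> Path:
    """Validate payload, then write it as training_metrics.json under model_dir."""
    validate_training_metrics(payload)
    destination = model_dir / "training_metrics.json"
    destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return destination
```

Inside a SageMaker training container, `model_dir` would normally be `/opt/ml/model`, so that the file ends up in `model.tar.gz` alongside the trained weights.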
@@ -153,6 +153,20 @@ Or pass the job name explicitly:
 qc-cli train status qc-cli-YYYYMMDD-HHMMSS
 ```
 
+To submit the job, wait for completion, and automatically import metrics and register the model, run:
+
+```bash
+qc-cli train start --upload-metrics
+```
+
+The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
+
+The metrics can also be uploaded using:
+
+```bash
+qc-cli mlflow upload-metrics
+```
+
 ## SageMaker Outputs
 
 When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
@@ -163,10 +177,15 @@ This example writes:
 best.pt
 model.onnx
 metrics.json
+training_metrics.json
 ```
 
 The archive is stored under the configured `s3.model_prefix`.
 
+The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
+losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
+are more meaningful than classification accuracy when assessing model quality.
+
 ## 6. Configure Qualcomm AI Hub
 
 Authenticate with Qualcomm AI Hub:
@@ -12,6 +12,7 @@ from typing import Any
 
 import yaml
 from sanitize_onnx import sanitize_onnx
+from training_metrics import write_training_metrics
 from ultralytics import YOLO  # type: ignore[reportMissingImports]
 
 
@@ -101,6 +102,7 @@ def main() -> None:
     if not trained_weights.exists():
         raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
 
+    write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
     copy_if_exists(trained_weights, model_dir / "best.pt")
     trained_model = YOLO(str(trained_weights))
     onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
82
examples/meter-detection/source/training_metrics.py
Normal file
@@ -0,0 +1,82 @@
+import csv
+import json
+import math
+import re
+from pathlib import Path
+from typing import Any
+
+METRIC_NAMES = {
+    "metrics/precision(B)": "val.precision",
+    "metrics/recall(B)": "val.recall",
+    "metrics/mAP50(B)": "val.map50",
+    "metrics/mAP50-95(B)": "val.map50_95",
+    "train/box_loss": "train.box_loss",
+    "train/cls_loss": "train.cls_loss",
+    "train/dfl_loss": "train.dfl_loss",
+    "val/box_loss": "val.box_loss",
+    "val/cls_loss": "val.cls_loss",
+    "val/dfl_loss": "val.dfl_loss",
+    "time": "train.elapsed_seconds",
+}
+
+
+def write_training_metrics(results_csv: Path, destination: Path) -> None:
+    steps = _read_metric_steps(results_csv)
+    summary = _build_summary(steps)
+    payload = {
+        "schema_version": 1,
+        "steps": steps,
+        "summary": summary,
+    }
+    destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+    print(f"Saved {destination}")
+
+
+def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
+    if not results_csv.is_file():
+        raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
+
+    steps: list[dict[str, Any]] = []
+    with results_csv.open(newline="", encoding="utf-8") as csv_file:
+        for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
+            row = {str(key).strip(): value for key, value in raw_row.items()}
+            raw_epoch = row.pop("epoch", row_index)
+            step = int(float(raw_epoch))
+            metrics: dict[str, float] = {}
+            for source_name, raw_value in row.items():
+                if raw_value is None or not raw_value.strip():
+                    continue
+                try:
+                    value = float(raw_value)
+                except ValueError:
+                    continue
+                if math.isfinite(value):
+                    metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
+            steps.append({"step": step, "metrics": metrics})
+    return steps
+
+
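
`_read_metric_steps` strips whitespace from every CSV column name before using it. The sketch below shows why that matters, assuming a `results.csv` whose header cells are space-padded (some Ultralytics releases appear to write them that way). The sample text and the `read_rows` helper are invented for illustration.

```python
import csv
import io

# Invented sample: header cells and values are padded with spaces.
SAMPLE = (
    "   epoch,   train/box_loss,   metrics/mAP50-95(B)\n"
    "       1,          1.2345,                 0.4100\n"
)


def read_rows(text: str) -> list[dict[str, float]]:
    """Parse CSV text into rows keyed by stripped column names."""
    rows = []
    for raw_row in csv.DictReader(io.StringIO(text)):
        # Without strip(), the key would be '   epoch' and a lookup by 'epoch' would miss.
        rows.append({key.strip(): float(value) for key, value in raw_row.items()})
    return rows
```

`float()` already tolerates surrounding whitespace in the values, so only the keys need explicit stripping.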
+def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
+    if not steps:
+        return {}
+
+    summary: dict[str, float] = {}
+    final_step = steps[-1]
+    summary["summary.final_epoch"] = float(final_step["step"])
+    for name, value in final_step["metrics"].items():
+        summary[f"summary.final.{name}"] = value
+
+    scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
+    if scored_steps:
+        best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
+        summary["summary.best_epoch"] = float(best_step["step"])
+        summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
+        if "val.map50" in best_step["metrics"]:
+            summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
+    return summary
+
+
+def _normalize_metric_name(name: str) -> str:
+    normalized = name.replace("/", ".")
+    normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
+    return normalized.strip("._") or "unnamed"
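
For column names that are not listed in `METRIC_NAMES`, `_normalize_metric_name` derives the metric name. The snippet below copies that function from this diff and applies it to a few sample inputs. The inputs are illustrative and not taken from a real run.

```python
import re


def _normalize_metric_name(name: str) -> str:
    # Copied from training_metrics.py in this pull request.
    normalized = name.replace("/", ".")
    normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
    return normalized.strip("._") or "unnamed"


examples = {
    name: _normalize_metric_name(name)
    for name in ("lr/pg0", "metrics/F1(B)", "val seg loss", "/")
}
# lr/pg0        -> lr.pg0
# metrics/F1(B) -> metrics.F1_B   (each run of disallowed characters becomes "_", then trailing "_" is stripped)
# val seg loss  -> val_seg_loss
# /             -> unnamed        (nothing is left after stripping "." and "_")
```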
@@ -1,3 +1,6 @@
+import os
+from collections.abc import Generator
+from contextlib import contextmanager
 from typing import Any, cast
 
 import boto3
@@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
     client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
     response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
     return str(response["AuthorizedUrl"])
+
+
+@contextmanager
+def tracking_auth_env(profile: str, region: str) -> Generator[None]:
+    credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
+    if credentials is None:
+        raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
+
+    frozen_credentials = credentials.get_frozen_credentials()
+    if not frozen_credentials.access_key or not frozen_credentials.secret_key:
+        raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
+
+    env_updates = {
+        "AWS_PROFILE": profile,
+        "AWS_DEFAULT_REGION": region,
+        "AWS_REGION": region,
+        "AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
+        "AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
+    }
+    if frozen_credentials.token:
+        env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
+
+    restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
+    previous_env = {key: os.environ.get(key) for key in restore_keys}
+    try:
+        os.environ.update(env_updates)
+        if not frozen_credentials.token:
+            os.environ.pop("AWS_SESSION_TOKEN", None)
+        yield
+    finally:
+        for key, value in previous_env.items():
+            if value is None:
+                os.environ.pop(key, None)
+            else:
+                os.environ[key] = value
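
`tracking_auth_env` follows a save, override, restore pattern for environment variables. Reduced to its core and without boto3, the pattern looks roughly like the sketch below. `temporary_env` and its `unset` parameter are hypothetical names chosen for this illustration.

```python
import os
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def temporary_env(updates: dict[str, str], unset: tuple[str, ...] = ()) -> Iterator[None]:
    """Apply updates and remove the unset keys for the duration of the with-block."""
    keys = set(updates) | set(unset)
    # Remember the previous value of every key we touch (None means "was not set").
    previous = {key: os.environ.get(key) for key in keys}
    try:
        os.environ.update(updates)
        for key in unset:
            os.environ.pop(key, None)
        yield
    finally:
        # Restore every touched key, even if the with-block raised.
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
```

Because `os.environ` is process-wide, this pattern is not safe when several threads need different credentials at the same time.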
@@ -0,0 +1 @@
+"""Cloud provider adapters."""
77
src/cloud/mlflow.py
Normal file
@@ -0,0 +1,77 @@
+from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Any, Protocol
+
+from src.aws import mlflow as aws_mlflow
+from src.config import Config
+
+
+class MlflowTrackingBackend(Protocol):
+    @property
+    def provider_name(self) -> str: ...
+
+    @property
+    def profile(self) -> str: ...
+
+    @property
+    def region(self) -> str: ...
+
+    def get_tracking_uri(self, tracking_server_name: str) -> str: ...
+
+    def auth_env(self) -> AbstractContextManager[None]: ...
+
+    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
+
+    def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
+
+    def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
+
+    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
+
+
+@dataclass(frozen=True)
+class AwsMlflowTrackingBackend:
+    profile: str
+    region: str
+    provider_name: str = "aws"
+
+    def get_tracking_uri(self, tracking_server_name: str) -> str:
+        return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
+
+    def auth_env(self) -> AbstractContextManager[None]:
+        return aws_mlflow.tracking_auth_env(self.profile, self.region)
+
+    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
+        return {
+            "provider.name": self.provider_name,
+            "provider.region": region,
+            "provider.profile": profile,
+            "sagemaker.role_arn": role_arn,
+            "sagemaker.job_name": training_job.job_name,
+            "sagemaker.training_image": training_job.image_uri,
+            "sagemaker.instance_type": training_job.instance_type,
+            "sagemaker.instance_count": training_job.instance_count,
+            "sagemaker.s3_train_uri": training_job.s3_train_uri,
+            "sagemaker.s3_output_path": training_job.s3_output_path,
+            "sagemaker.entry_point": training_job.entry_point,
+            "sagemaker.source_dir": training_job.source_dir,
+        }
+
+    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
+        return {"sagemaker.job_name": training_job.job_name}
+
+    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
+        return {
+            "sagemaker.training_status": training_job_status.status,
+            "sagemaker.created_at": training_job_status.created,
+            "sagemaker.modified_at": training_job_status.modified,
+            "sagemaker.model_artifacts": training_job_status.model_artifacts,
+            "sagemaker.failure_reason": training_job_status.failure_reason,
+        }
+
+    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
+        return {"sagemaker.job_name": training_job_status.name}
+
+
+def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
+    return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)
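
`MlflowTrackingBackend` is a `typing.Protocol`, so `AwsMlflowTrackingBackend` satisfies it structurally without inheriting from it. A reduced sketch of the same idea follows. The names `Backend`, `LocalBackend`, and `describe` are hypothetical and not part of this pull request.

```python
from dataclasses import dataclass
from typing import Protocol


class Backend(Protocol):
    @property
    def provider_name(self) -> str: ...

    def get_tracking_uri(self, tracking_server_name: str) -> str: ...


@dataclass(frozen=True)
class LocalBackend:
    # A plain dataclass field is enough to satisfy the read-only provider_name property.
    provider_name: str = "local"

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        return f"file:///tmp/{tracking_server_name}"


def describe(backend: Backend, server: str) -> str:
    """Accept any object whose shape matches Backend; no inheritance is required."""
    return f"{backend.provider_name}: {backend.get_tracking_uri(server)}"
```

A second provider could be added by writing another class with the same attributes and methods, and a static type checker would accept it wherever `Backend` is expected.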
@@ -2,8 +2,11 @@ import webbrowser
 
 import typer
 
+from src import state as state_ops
 from src.aws import mlflow as aws_mlflow
 from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
+from src.config import MlflowMode
+from src.tracking.upload import upload_training_metrics
 
 app = typer.Typer(help="Manage MLflow tracking server access")
 
@@ -39,3 +42,54 @@ def open_mlflow(config: str = CONFIG_OPT) -> None:
         CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
     else:
         CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
+
+
+@app.command(name="upload-metrics")
+def upload_metrics(
+    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
+    force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
+    config: str = CONFIG_OPT,
+) -> None:
+    """Upload a completed training job's metric history to MLflow."""
+    cfg = load_cfg(config)
+    if cfg.mlflow.mode is MlflowMode.disabled:
+        CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
+        raise typer.Exit(1)
+
+    st = state_ops.store(config)
+    if not job_name:
+        job_name = st.get_last_training_job()
+    if not job_name:
+        CONSOLE.print(
+            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
+        )
+        raise typer.Exit(1)
+
+    if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
+        CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
+        return
+
+    try:
+        result = upload_training_metrics(
+            job_name=job_name,
+            config_path=config,
+            cfg=cfg,
+            force=force,
+        )
+    except Exception as e:
+        CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
+        raise typer.Exit(1)
+
+    if result.metrics_history_uploaded:
+        CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
+    else:
+        CONSOLE.print(
+            f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
+            f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
+        )
+    CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
+    if result.registered_model_version:
+        CONSOLE.print(
+            f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
+            "([cyan]experiment-latest[/cyan])"
+        )
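
The `upload-metrics` command above skips work when the job is already marked as uploaded, unless `--force` is given. That decision can be sketched as below. The state layout (a JSON object keyed by job name) and the helper names `should_upload` and `mark_uploaded` are assumptions made for this example. Only the `mlflow_metrics_uploaded` key appears in the diff.

```python
import json
from pathlib import Path


def should_upload(state_path: Path, job_name: str, force: bool = False) -> bool:
    """Return True when metrics for job_name still need uploading, or force is set."""
    if force or not state_path.is_file():
        return True
    state = json.loads(state_path.read_text(encoding="utf-8"))
    return not state.get(job_name, {}).get("mlflow_metrics_uploaded", False)


def mark_uploaded(state_path: Path, job_name: str) -> None:
    """Record a successful upload for job_name in the (hypothetical) state file."""
    state = json.loads(state_path.read_text(encoding="utf-8")) if state_path.is_file() else {}
    state.setdefault(job_name, {})["mlflow_metrics_uploaded"] = True
    state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
```

Recording the flag only after a successful upload is what makes a failed run retryable without `--force`.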
@@ -1,3 +1,4 @@
+import time
 from datetime import datetime
 from pathlib import Path
 
@@ -11,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
 from src.config import Config, MlflowMode
 from src.infra.state import read_infra_state
 from src.tracking.mlflow import MlflowTracker
+from src.tracking.upload import upload_training_metrics
 
 app = typer.Typer(help="Manage SageMaker training jobs")
 
@@ -21,6 +23,8 @@ _STATUS_COLOR = {
     "Stopping": "yellow",
     "Stopped": "dim",
 }
+_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
+DEFAULT_POLL_INTERVAL_SECONDS = 30
 
 
 def _tracker(cfg):
@@ -48,11 +52,100 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
     raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
 
 
+def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
+    color = _STATUS_COLOR.get(status.status, "white")
+    CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
+    CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
+    if status.created:
+        CONSOLE.print(f"Created: {status.created}")
+    if status.model_artifacts:
+        CONSOLE.print(f"Artifacts: {status.model_artifacts}")
+    if status.failure_reason:
+        CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
+
+
+def _wait_and_upload_metrics(
+    *,
+    job_name: str,
+    poll_interval: int,
+    config_path: str,
+    cfg: Config,
+) -> None:
+    st = state_ops.store(config_path)
+    previous_status: str | None = None
+    try:
+        while True:
+            training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
+            if training_status.status != previous_status:
+                color = _STATUS_COLOR.get(training_status.status, "white")
+                CONSOLE.print(
+                    f"Job [cyan]{training_status.name}[/cyan]: "
+                    f"[{color}]{training_status.status}[/{color}]"
+                )
+                previous_status = training_status.status
+            if training_status.status in _TERMINAL_STATUSES:
+                _print_training_status(training_status)
+                if training_status.status != "Completed":
+                    raise typer.Exit(1)
+                try:
+                    result = upload_training_metrics(
+                        job_name=job_name,
+                        config_path=config_path,
+                        cfg=cfg,
+                    )
+                    if result.metrics_history_uploaded:
+                        CONSOLE.print(
+                            f"[green]✓[/green] Uploaded training metrics to MLflow run "
+                            f"[cyan]{result.run_id}[/cyan]."
+                        )
+                    else:
+                        CONSOLE.print(
+                            "[yellow]No training_metrics.json was found in the SageMaker model artifact. "
+                            "Uploaded SageMaker final metrics only.[/yellow]"
+                        )
+                    if result.registered_model_version:
+                        CONSOLE.print(
+                            f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
+                            "([cyan]experiment-latest[/cyan])"
+                        )
+                except Exception as e:
+                    CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
+                    CONSOLE.print(
+                        f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
+                    )
+                    raise typer.Exit(1)
+                job_state = st.get_training_job(job_name)
+                if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
+                    CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
+                return
+            time.sleep(poll_interval)
+    except KeyboardInterrupt:
+        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
+        raise typer.Exit(130)
+
+
 @app.command()
-def start(config: str = CONFIG_OPT) -> None:
+def start(
+    upload_metrics: bool = typer.Option(
+        False,
+        "--upload-metrics",
+        help="Wait for completion, then upload training metrics to MLflow",
+    ),
+    poll_interval: int = typer.Option(
+        DEFAULT_POLL_INTERVAL_SECONDS,
+        "--poll-interval",
+        min=1,
+        help="Seconds between status checks when --upload-metrics is used",
+    ),
+    config: str = CONFIG_OPT,
+) -> None:
     """Submit a SageMaker training job."""
     cfg = load_cfg(config)
 
+    if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
+        CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
+        raise typer.Exit(1)
+
     if not cfg.sagemaker.training.image_uri:
         CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
         CONSOLE.print(
@@ -89,12 +182,20 @@ def start(config: str = CONFIG_OPT) -> None:
 
     st = state_ops.store(config)
     st.set_last_training_job(job_name)
-    run_id = tracker.start_training_run(
-        training_job,
-        region=cfg.aws.region,
-        profile=cfg.aws.profile,
-        role_arn=role_arn,
-    )
+    try:
+        run_id = tracker.start_training_run(
+            training_job,
+            region=cfg.aws.region,
+            profile=cfg.aws.profile,
+            role_arn=role_arn,
+        )
+    except Exception as e:
+        run_id = None
+        CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
+        CONSOLE.print(
+            "The SageMaker job is still running. Upload metrics after completion with "
+            f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
+        )
     if run_id:
         st.update_training_job(job_name, mlflow_run_id=run_id)
 
@@ -102,7 +203,15 @@ def start(config: str = CONFIG_OPT) -> None:
     if run_id:
         CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
         CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
-    CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
+    if upload_metrics:
+        _wait_and_upload_metrics(
+            job_name=job_name,
+            poll_interval=poll_interval,
+            config_path=config,
+            cfg=cfg,
+        )
+    else:
+        CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
 
 
 @app.command()
@@ -123,35 +232,7 @@ def status(
         raise typer.Exit(1)
 
     status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
-    color = _STATUS_COLOR.get(status.status, "white")
-
-    CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
-    CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
-    if status.created:
-        CONSOLE.print(f"Created: {status.created}")
-    if status.model_artifacts:
-        CONSOLE.print(f"Artifacts: {status.model_artifacts}")
-    if status.failure_reason:
-        CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
-
-    job_state = st.get_training_job(job_name)
-    run_id = job_state.get("mlflow_run_id")
-    already_registered = job_state.get("registered_model_version")
-    if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
-        tracker = _tracker(cfg)
-        version = tracker.finalize_training_run(
-            run_id=str(run_id),
-            training_job_status=status,
-        )
-        updates = {"mlflow_finalized_status": status.status}
-        if version:
-            updates["registered_model_version"] = version
-        st.update_training_job(job_name, **updates)
-        if version:
-            st.set_latest_experiment_model_version(version)
-            CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
-    if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
-        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
+    _print_training_status(status)
 
 
 @app.command(name="list")
@@ -1,3 +1,3 @@
-from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
+from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
 
-__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
+__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]
93
src/tracking/metrics.py
Normal file
@@ -0,0 +1,93 @@
+import json
+import math
+import tarfile
+from dataclasses import dataclass
+from pathlib import PurePosixPath
+from typing import Any
+
+METRICS_ARTIFACT_NAME = "training_metrics.json"
+METRICS_SCHEMA_VERSION = 1
+MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
+
+
+@dataclass(frozen=True)
+class MetricStep:
+    step: int
+    metrics: dict[str, float]
+
+
+@dataclass(frozen=True)
+class TrainingMetrics:
+    steps: list[MetricStep]
+    summary: dict[str, float]
+    raw: dict[str, Any]
+
+
+def parse_training_metrics(data: bytes) -> TrainingMetrics:
+    try:
+        value = json.loads(data)
+    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
+        raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
+    if not isinstance(value, dict):
+        raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
+    if value.get("schema_version") != METRICS_SCHEMA_VERSION:
+        raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
+
+    raw_steps = value.get("steps")
+    if not isinstance(raw_steps, list):
+        raise ValueError("training metrics 'steps' must be a list")
+
+    steps: list[MetricStep] = []
+    previous_step: int | None = None
+    for index, raw_step in enumerate(raw_steps):
+        if not isinstance(raw_step, dict):
+            raise ValueError(f"training metrics step {index} must be an object")
+        step = raw_step.get("step")
+        if isinstance(step, bool) or not isinstance(step, int) or step < 0:
+            raise ValueError(f"training metrics step {index} has an invalid 'step'")
+        if previous_step is not None and step <= previous_step:
+            raise ValueError("training metrics steps must be unique and strictly increasing")
+        metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
+        steps.append(MetricStep(step=step, metrics=metrics))
+        previous_step = step
+
+    summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
+    return TrainingMetrics(steps=steps, summary=summary, raw=value)
+
+
+def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
+    with tarfile.open(archive_path, mode="r:*") as archive:
+        matches = [
+            member
+            for member in archive.getmembers()
+            if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
+        ]
+        if not matches:
+            return None
+        if len(matches) > 1:
+            raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
+        if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
+            raise ValueError(
+                f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
+            )
+        extracted = archive.extractfile(matches[0])
+        if extracted is None:
+            raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
+        return extracted.read()
+
+
+def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
+    if not isinstance(value, dict):
+        raise ValueError(f"{context} 'metrics' must be an object")
+
+    metrics: dict[str, float] = {}
+    for raw_name, raw_value in value.items():
+        if not isinstance(raw_name, str) or not raw_name:
+            raise ValueError(f"{context} contains an invalid metric name")
+        if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
+            raise ValueError(f"{context} metric '{raw_name}' must be numeric")
+        metric_value = float(raw_value)
+        if not math.isfinite(metric_value):
+            raise ValueError(f"{context} metric '{raw_name}' must be finite")
+        metrics[raw_name] = metric_value
+    return metrics
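Review note on `src/tracking/metrics.py`: the sketch below illustrates the shape of a `training_metrics.json` document that the parser in this file appears to accept (schema version 1, strictly increasing non-negative integer steps, finite numeric metric values) and how the file can be located by basename inside a model archive. It is a minimal, standard-library-only illustration, not the project's code: the payload values, the `output/` path prefix, and the helper names `build_archive`, `find_metrics`, and `is_valid` are assumptions made for this example, and `is_valid` deliberately skips the `summary` check for brevity.

```python
import io
import json
import math
import tarfile
from pathlib import PurePosixPath
from typing import Optional

# Hypothetical payload a training script might write; the values are illustrative only.
payload = {
    "schema_version": 1,
    "steps": [
        {"step": 0, "metrics": {"loss": 0.9}},
        {"step": 10, "metrics": {"loss": 0.5, "accuracy": 0.8}},
    ],
    "summary": {"final_loss": 0.5},
}


def build_archive(document: dict) -> bytes:
    """Pack the document as output/training_metrics.json in an in-memory tar.gz."""
    data = json.dumps(document).encode("utf-8")
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
        info = tarfile.TarInfo(name="output/training_metrics.json")
        info.size = len(data)
        archive.addfile(info, io.BytesIO(data))
    return buffer.getvalue()


def find_metrics(archive_bytes: bytes) -> Optional[dict]:
    """Return the first regular file named training_metrics.json, matched by basename."""
    with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:*") as archive:
        for member in archive.getmembers():
            if member.isfile() and PurePosixPath(member.name).name == "training_metrics.json":
                extracted = archive.extractfile(member)
                return json.loads(extracted.read()) if extracted else None
    return None


def _finite_number(value: object) -> bool:
    """True for int/float values that are finite; bool is rejected explicitly."""
    return not isinstance(value, bool) and isinstance(value, (int, float)) and math.isfinite(value)


def is_valid(document: dict) -> bool:
    """Simplified version of the step checks: schema version, ordering, finite numbers."""
    if document.get("schema_version") != 1 or not isinstance(document.get("steps"), list):
        return False
    previous = -1  # rejects negative steps as well as non-increasing ones
    for entry in document["steps"]:
        step = entry.get("step") if isinstance(entry, dict) else None
        if isinstance(step, bool) or not isinstance(step, int) or step <= previous:
            return False
        values = entry.get("metrics")
        if not isinstance(values, dict) or not all(_finite_number(v) for v in values.values()):
            return False
        previous = step
    return True
```

Matching on the basename means the training script can place the file anywhere inside the archive; the diff's reader additionally rejects archives that contain more than one such file and enforces a size cap, which this sketch omits.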
@@ -1,18 +1,46 @@
 import os
+import tempfile
 from dataclasses import dataclass
 from typing import Any, Protocol
 
 import mlflow
 from mlflow.tracking import MlflowClient
 
-from src.aws import mlflow as aws_mlflow
+from src.aws import s3
+from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
 from src.config import Config, MlflowMode
+from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
 
 
+@dataclass(frozen=True)
+class FinalizeResult:
+    registered_model_version: str | None = None
+    warnings: tuple[str, ...] = ()
+
+
 class Tracker(Protocol):
     def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
 
-    def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ...
+    def finalize_training_run(
+        self,
+        *,
+        run_id: str | None,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+        command: str,
+    ) -> FinalizeResult: ...
+
+    def ensure_training_run(self, job_name: str) -> str: ...
+
+    def upload_training_metrics(
+        self,
+        *,
+        run_id: str,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+    ) -> bool: ...
 
 
 @dataclass(frozen=True)
@@ -20,8 +48,29 @@ class NoopTracker:
     def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
         return None
 
-    def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
-        return None
+    def finalize_training_run(
+        self,
+        *,
+        run_id: str | None,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+        command: str,
+    ) -> FinalizeResult:
+        return FinalizeResult()
+
+    def ensure_training_run(self, job_name: str) -> str:
+        raise RuntimeError("MLflow is disabled.")
+
+    def upload_training_metrics(
+        self,
+        *,
+        run_id: str,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+    ) -> bool:
+        raise RuntimeError("MLflow is disabled.")
 
 
 @dataclass(frozen=True)
@@ -30,6 +79,7 @@ class MlflowTracker:
     experiment_name: str
     registered_model_name: str
     register_trained_models: bool
+    tracking_backend: MlflowTrackingBackend
 
     @classmethod
     def from_config(cls, cfg: Config) -> Tracker:
@@ -42,94 +92,138 @@ class MlflowTracker:
         if not tracking_server_name:
             raise RuntimeError("MLflow tracking server name could not be resolved.")
 
-        tracking_uri = aws_mlflow.get_tracking_server_arn(
-            cfg.aws.region,
-            cfg.aws.profile,
-            tracking_server_name,
-        )
-        mlflow.set_tracking_uri(tracking_uri)
-        mlflow.set_experiment(cfg.mlflow.experiment_name)
+        tracking_backend = mlflow_tracking_backend_from_config(cfg)
+
+        tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
+        with tracking_backend.auth_env():
+            mlflow.set_tracking_uri(tracking_uri)
+            mlflow.set_experiment(cfg.mlflow.experiment_name)
 
         return cls(
             tracking_uri=tracking_uri,
             experiment_name=cfg.mlflow.experiment_name,
             registered_model_name=cfg.mlflow.registered_model_name,
             register_trained_models=cfg.mlflow.register_trained_models,
+            tracking_backend=tracking_backend,
         )
 
     def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
-        run = mlflow.start_run(run_name=training_job.job_name)
-        run_id = str(run.info.run_id)
+        with self.tracking_backend.auth_env():
+            with mlflow.start_run(run_name=training_job.job_name) as run:
+                run_id = str(run.info.run_id)
+                self._log_params(
+                    self.tracking_backend.training_run_params(
+                        training_job,
+                        region=region,
+                        profile=profile,
+                        role_arn=role_arn,
+                    )
+                )
+                self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
+                mlflow.set_tags(
+                    {
+                        "qc_cli.stage": "experiment",
+                        "qc_cli.artifact_kind": "trained_source",
+                        "qc_cli.source": self.tracking_backend.provider_name,
+                        "qc_cli.command": "train start",
+                        **self.tracking_backend.training_run_tags(training_job),
+                    }
+                )
+                return run_id
 
-        params = {
-            "aws.region": region,
-            "aws.profile": profile,
-            "sagemaker.role_arn": role_arn,
-            "sagemaker.job_name": training_job.job_name,
-            "sagemaker.training_image": training_job.image_uri,
-            "sagemaker.instance_type": training_job.instance_type,
-            "sagemaker.instance_count": training_job.instance_count,
-            "sagemaker.s3_train_uri": training_job.s3_train_uri,
-            "sagemaker.s3_output_path": training_job.s3_output_path,
-            "sagemaker.entry_point": training_job.entry_point,
-            "sagemaker.source_dir": training_job.source_dir,
-        }
-        self._log_params(params)
-        self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
-        mlflow.set_tags(
-            {
-                "qc_cli.stage": "experiment",
-                "qc_cli.artifact_kind": "trained_source",
-                "qc_cli.source": "sagemaker",
-                "qc_cli.command": "train start",
-                "sagemaker.job_name": training_job.job_name,
-            }
-        )
-        mlflow.end_run()
-        return run_id
-
-    def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
+    def finalize_training_run(
+        self,
+        *,
+        run_id: str | None,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+        command: str,
+    ) -> FinalizeResult:
         if not run_id:
-            return None
+            return FinalizeResult()
 
-        with mlflow.start_run(run_id=run_id):
-            self._log_params(
-                {
-                    "sagemaker.training_status": training_job_status.status,
-                    "sagemaker.created_at": training_job_status.created,
-                    "sagemaker.modified_at": training_job_status.modified,
-                    "sagemaker.model_artifacts": training_job_status.model_artifacts,
-                    "sagemaker.failure_reason": training_job_status.failure_reason,
-                }
-            )
-            self._log_final_metrics(training_job_status.raw)
-            mlflow.set_tag("qc_cli.command", "train status")
+        with self.tracking_backend.auth_env():
+            with mlflow.start_run(run_id=run_id):
+                self._log_params(self.tracking_backend.training_status_params(training_job_status))
+                self._log_final_metrics(training_job_status.raw)
+                mlflow.set_tag("qc_cli.command", command)
 
-            if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
-                mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
-                return None
+                if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
+                    mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
+                    return FinalizeResult()
 
-            if not self.register_trained_models:
-                return None
+                if not self.register_trained_models:
+                    return FinalizeResult()
 
+                client = MlflowClient()
+                self._ensure_registered_model(client, self.registered_model_name)
+                version = client.create_model_version(
+                    name=self.registered_model_name,
+                    source=training_job_status.model_artifacts,
+                    run_id=run_id,
+                    tags={
+                        "qc_cli.stage": "experiment",
+                        "qc_cli.artifact_kind": "trained_source",
+                        "qc_cli.source": self.tracking_backend.provider_name,
+                        **self.tracking_backend.model_version_tags(training_job_status),
+                    },
+                )
+                version_number = str(version.version)
+                client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
+                mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
+                mlflow.set_tag("qc_cli.registered_model_version", version_number)
+                return FinalizeResult(registered_model_version=version_number)
+
+    def ensure_training_run(self, job_name: str) -> str:
+        with self.tracking_backend.auth_env():
             client = MlflowClient()
-            self._ensure_registered_model(client, self.registered_model_name)
-            version = client.create_model_version(
-                name=self.registered_model_name,
-                source=training_job_status.model_artifacts,
-                run_id=run_id,
+            experiment = client.get_experiment_by_name(self.experiment_name)
+            if experiment is None:
+                experiment_id = mlflow.create_experiment(self.experiment_name)
+            else:
+                experiment_id = experiment.experiment_id
+
+            for run in client.search_runs([experiment_id], max_results=1000):
+                if run.data.tags.get("sagemaker.job_name") == job_name:
+                    return str(run.info.run_id)
+
+            run = client.create_run(
+                experiment_id,
+                run_name=job_name,
                 tags={
                     "qc_cli.stage": "experiment",
                     "qc_cli.artifact_kind": "trained_source",
-                    "qc_cli.source": "sagemaker",
-                    "sagemaker.job_name": training_job_status.name,
+                    "qc_cli.source": self.tracking_backend.provider_name,
+                    "qc_cli.command": "mlflow upload-metrics",
+                    "sagemaker.job_name": job_name,
                 },
             )
-            version_number = str(version.version)
-            client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
-            mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
-            mlflow.set_tag("qc_cli.registered_model_version", version_number)
-            return version_number
+            return str(run.info.run_id)
+
+    def upload_training_metrics(
+        self,
+        *,
+        run_id: str,
+        training_job_status: Any,
+        region: str,
+        profile: str,
+    ) -> bool:
+        if not training_job_status.model_artifacts:
+            raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
+
+        with self.tracking_backend.auth_env():
+            with mlflow.start_run(run_id=run_id):
+                self._log_params(self.tracking_backend.training_status_params(training_job_status))
+                self._log_final_metrics(training_job_status.raw)
+                history_uploaded = self._log_training_metrics(
+                    training_job_status.model_artifacts,
+                    region=region,
+                    profile=profile,
+                )
+                mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
+                mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
+                return history_uploaded
 
     def _log_params(self, params: dict[str, Any]) -> None:
         cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -146,6 +240,26 @@ class MlflowTracker:
         if metrics:
             mlflow.log_metrics(metrics)
 
+    def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
+        with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
+            archive_path = s3.download_file(
+                region,
+                profile,
+                model_artifacts,
+                os.path.join(temp_dir, "model.tar.gz"),
+            )
+            metrics_data = read_training_metrics_from_tar(archive_path)
+            if metrics_data is None:
+                return False
+            metrics = parse_training_metrics(metrics_data)
+            for metric_step in metrics.steps:
+                if metric_step.metrics:
+                    mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
+            if metrics.summary:
+                mlflow.log_metrics(metrics.summary)
+            mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
+            return True
+
     def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
         try:
             client.get_registered_model(name)
75
src/tracking/upload.py
Normal file
@@ -0,0 +1,75 @@
+from dataclasses import dataclass
+
+from src import state as state_ops
+from src.aws import sagemaker as sm_ops
+from src.config import Config, MlflowMode
+from src.tracking.mlflow import MlflowTracker
+
+
+@dataclass(frozen=True)
+class MetricsUploadResult:
+    run_id: str
+    registered_model_version: str | None = None
+    metrics_history_uploaded: bool = True
+
+
+def upload_training_metrics(
+    *,
+    job_name: str,
+    config_path: str,
+    cfg: Config,
+    force: bool = False,
+) -> MetricsUploadResult:
+    if cfg.mlflow.mode is MlflowMode.disabled:
+        raise RuntimeError("MLflow is disabled in config.yaml.")
+
+    st = state_ops.store(config_path)
+    job_state = st.get_training_job(job_name)
+    if job_state.get("mlflow_metrics_uploaded") and not force:
+        return MetricsUploadResult(
+            run_id=str(job_state.get("mlflow_run_id") or ""),
+            registered_model_version=(
+                str(job_state["registered_model_version"])
+                if job_state.get("registered_model_version")
+                else None
+            ),
+            metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
+        )
+
+    status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
+    if status.status != "Completed":
+        raise RuntimeError(
+            f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
+        )
+
+    tracker = MlflowTracker.from_config(cfg)
+    run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
+    st.update_training_job(job_name, mlflow_run_id=run_id)
+    metrics_history_uploaded = tracker.upload_training_metrics(
+        run_id=run_id,
+        training_job_status=status,
+        region=cfg.aws.region,
+        profile=cfg.aws.profile,
+    )
+    finalized = tracker.finalize_training_run(
+        run_id=run_id,
+        training_job_status=status,
+        region=cfg.aws.region,
+        profile=cfg.aws.profile,
+        command="mlflow upload-metrics",
+    )
+    updates = {
+        "mlflow_metrics_uploaded": True,
+        "mlflow_metrics_history_uploaded": metrics_history_uploaded,
+        "mlflow_finalized_status": status.status,
+    }
+    if finalized.registered_model_version:
+        updates["registered_model_version"] = finalized.registered_model_version
+    st.update_training_job(job_name, **updates)
+    if finalized.registered_model_version:
+        st.set_latest_experiment_model_version(finalized.registered_model_version)
+    return MetricsUploadResult(
+        run_id=run_id,
+        registered_model_version=finalized.registered_model_version,
+        metrics_history_uploaded=metrics_history_uploaded,
+    )
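Review note on `src/tracking/upload.py`: the function above appears to be idempotent per job. It short-circuits when the local state already records an upload, reuses a stored run id when one exists, and persists the run id before uploading so that a retry lands on the same run. The sketch below isolates that control flow with an in-memory dictionary standing in for the state store. It is an illustration under stated assumptions, not the project's implementation: `UploadOutcome`, `upload_once`, and the `create_run` and `upload` callables are hypothetical names, and only the state keys `mlflow_run_id` and `mlflow_metrics_uploaded` come from the diff.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class UploadOutcome:
    run_id: str
    performed: bool  # False when cached state short-circuited the upload


def upload_once(
    state: Dict[str, dict],
    job_name: str,
    create_run: Callable[[str], str],
    upload: Callable[[str], None],
    force: bool = False,
) -> UploadOutcome:
    """Upload metrics at most once per job unless force is set."""
    job_state = state.setdefault(job_name, {})
    if job_state.get("mlflow_metrics_uploaded") and not force:
        # Already uploaded: report the stored run id without touching the tracker.
        return UploadOutcome(run_id=str(job_state.get("mlflow_run_id") or ""), performed=False)

    # Reuse the stored run id if present; otherwise create (or recover) a run.
    run_id = str(job_state.get("mlflow_run_id") or create_run(job_name))
    # Persist the run id before uploading so a failed upload can be retried on the same run.
    job_state["mlflow_run_id"] = run_id
    upload(run_id)
    job_state["mlflow_metrics_uploaded"] = True
    return UploadOutcome(run_id=run_id, performed=True)
```

With this shape, calling `upload_once` twice for the same job performs one upload, and passing `force=True` repeats the upload against the same run rather than creating a new one.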