another update
This commit is contained in:
29
README.md
29
README.md
@@ -128,9 +128,14 @@ qc-cli init --force Overwrite an existing config file
|
||||
### `mlflow`
|
||||
|
||||
```
|
||||
qc-cli mlflow open Open a presigned MLflow UI URL in a browser
|
||||
qc-cli mlflow open Open a presigned MLflow UI URL
|
||||
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
|
||||
```
|
||||
|
||||
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run,
|
||||
imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`.
|
||||
Use `--force` to upload the metrics again.
|
||||
|
||||
### `infra`
|
||||
|
||||
```
|
||||
@@ -163,7 +168,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
|
||||
|
||||
```
|
||||
qc-cli train start Submit a SageMaker training job
|
||||
qc-cli train start --wait Submit, wait, and finalize MLflow tracking
|
||||
qc-cli train start --upload-metrics Submit, wait, and upload metrics
|
||||
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||
qc-cli train list List recent training jobs
|
||||
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
@@ -171,7 +176,7 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
|
||||
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||
|
||||
`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
|
||||
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||
|
||||
@@ -219,15 +224,19 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
|
||||
Current behavior:
|
||||
|
||||
1. `qc-cli train start` submits a SageMaker training job.
|
||||
2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default.
|
||||
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||
2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts.
|
||||
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
|
||||
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
|
||||
5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||
- `qc_cli.stage=experiment`
|
||||
- `qc_cli.artifact_kind=trained_source`
|
||||
- `qc_cli.source=sagemaker`
|
||||
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||
|
||||
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
|
||||
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics
|
||||
upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores
|
||||
the JSON as a run artifact:
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -239,7 +248,9 @@ Training scripts can include a `training_metrics.json` file in the SageMaker mod
|
||||
}
|
||||
```
|
||||
|
||||
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration.
|
||||
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
|
||||
strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained
|
||||
model or model registration.
|
||||
|
||||
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
||||
|
||||
|
||||
@@ -156,11 +156,17 @@ qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
||||
To submit the job, wait for completion, and automatically import metrics and register the model, run:
|
||||
|
||||
```bash
|
||||
qc-cli train start --wait
|
||||
qc-cli train start --upload-metrics
|
||||
```
|
||||
|
||||
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
||||
|
||||
The metrics can be also submitted using:
|
||||
|
||||
```bash
|
||||
qc-cli mlflow upload-metrics
|
||||
```
|
||||
|
||||
## SageMaker Outputs
|
||||
|
||||
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
|
||||
@@ -176,7 +182,9 @@ training_metrics.json
|
||||
|
||||
The archive is stored under the configured `s3.model_prefix`.
|
||||
|
||||
During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality.
|
||||
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
|
||||
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
|
||||
are more meaningful than classification accuracy when assessing model quality.
|
||||
|
||||
## 6. Configure Qualcomm AI Hub
|
||||
|
||||
|
||||
@@ -2,8 +2,11 @@ import webbrowser
|
||||
|
||||
import typer
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import MlflowMode
|
||||
from src.tracking.upload import upload_training_metrics
|
||||
|
||||
app = typer.Typer(help="Manage MLflow tracking server access")
|
||||
|
||||
@@ -39,3 +42,43 @@ def open_mlflow(config: str = CONFIG_OPT) -> None:
|
||||
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
|
||||
else:
|
||||
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
|
||||
|
||||
|
||||
@app.command(name="upload-metrics")
|
||||
def upload_metrics(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Upload a completed training job's metric history to MLflow."""
|
||||
cfg = load_cfg(config)
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
st = state_ops.store(config)
|
||||
if not job_name:
|
||||
job_name = st.get_last_training_job()
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
|
||||
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
|
||||
return
|
||||
|
||||
try:
|
||||
run_id = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
force=force,
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
|
||||
@@ -12,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra.state import read_infra_state
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
from src.tracking.upload import upload_training_metrics
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
@@ -102,7 +103,7 @@ def _finalize_terminal_job(
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_training_job(
|
||||
def _wait_and_upload_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
poll_interval: int,
|
||||
@@ -123,12 +124,21 @@ def _wait_for_training_job(
|
||||
previous_status = training_status.status
|
||||
if training_status.status in _TERMINAL_STATUSES:
|
||||
_print_training_status(training_status)
|
||||
_finalize_terminal_job(
|
||||
config_path=config_path,
|
||||
cfg=cfg,
|
||||
status=training_status,
|
||||
command="train start --wait",
|
||||
)
|
||||
if training_status.status != "Completed":
|
||||
raise typer.Exit(1)
|
||||
try:
|
||||
run_id = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config_path,
|
||||
cfg=cfg,
|
||||
)
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].")
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||
CONSOLE.print(
|
||||
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
@@ -141,18 +151,26 @@ def _wait_for_training_job(
|
||||
|
||||
@app.command()
|
||||
def start(
|
||||
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
|
||||
upload_metrics: bool = typer.Option(
|
||||
False,
|
||||
"--upload-metrics",
|
||||
help="Wait for completion, then upload training metrics to MLflow",
|
||||
),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between status checks when --wait is used",
|
||||
help="Seconds between status checks when --upload-metrics is used",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
|
||||
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not cfg.sagemaker.training.image_uri:
|
||||
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||
CONSOLE.print(
|
||||
@@ -189,12 +207,20 @@ def start(
|
||||
|
||||
st = state_ops.store(config)
|
||||
st.set_last_training_job(job_name)
|
||||
run_id = tracker.start_training_run(
|
||||
training_job,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
try:
|
||||
run_id = tracker.start_training_run(
|
||||
training_job,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
except Exception as e:
|
||||
run_id = None
|
||||
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
|
||||
CONSOLE.print(
|
||||
"The SageMaker job is still running. Upload metrics after completion with "
|
||||
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||
)
|
||||
if run_id:
|
||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||
|
||||
@@ -202,8 +228,8 @@ def start(
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
if wait:
|
||||
_wait_for_training_job(
|
||||
if upload_metrics:
|
||||
_wait_and_upload_metrics(
|
||||
job_name=job_name,
|
||||
poll_interval=poll_interval,
|
||||
config_path=config,
|
||||
|
||||
@@ -31,6 +31,17 @@ class Tracker(Protocol):
|
||||
command: str,
|
||||
) -> FinalizeResult: ...
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str: ...
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NoopTracker:
|
||||
@@ -48,6 +59,19 @@ class NoopTracker:
|
||||
) -> FinalizeResult:
|
||||
return FinalizeResult()
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str:
|
||||
raise RuntimeError("MLflow is disabled.")
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> None:
|
||||
raise RuntimeError("MLflow is disabled.")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MlflowTracker:
|
||||
@@ -73,7 +97,6 @@ class MlflowTracker:
|
||||
tracking_server_name,
|
||||
)
|
||||
mlflow.set_tracking_uri(tracking_uri)
|
||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||
|
||||
return cls(
|
||||
tracking_uri=tracking_uri,
|
||||
@@ -83,34 +106,33 @@ class MlflowTracker:
|
||||
)
|
||||
|
||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||
run = mlflow.start_run(run_name=training_job.job_name)
|
||||
run_id = str(run.info.run_id)
|
||||
|
||||
params = {
|
||||
"aws.region": region,
|
||||
"aws.profile": profile,
|
||||
"sagemaker.role_arn": role_arn,
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
"sagemaker.training_image": training_job.image_uri,
|
||||
"sagemaker.instance_type": training_job.instance_type,
|
||||
"sagemaker.instance_count": training_job.instance_count,
|
||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||
"sagemaker.entry_point": training_job.entry_point,
|
||||
"sagemaker.source_dir": training_job.source_dir,
|
||||
}
|
||||
self._log_params(params)
|
||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||
mlflow.set_tags(
|
||||
{
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": "sagemaker",
|
||||
"qc_cli.command": "train start",
|
||||
mlflow.set_experiment(self.experiment_name)
|
||||
with mlflow.start_run(run_name=training_job.job_name) as run:
|
||||
run_id = str(run.info.run_id)
|
||||
params = {
|
||||
"aws.region": region,
|
||||
"aws.profile": profile,
|
||||
"sagemaker.role_arn": role_arn,
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
"sagemaker.training_image": training_job.image_uri,
|
||||
"sagemaker.instance_type": training_job.instance_type,
|
||||
"sagemaker.instance_count": training_job.instance_count,
|
||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||
"sagemaker.entry_point": training_job.entry_point,
|
||||
"sagemaker.source_dir": training_job.source_dir,
|
||||
}
|
||||
)
|
||||
mlflow.end_run()
|
||||
self._log_params(params)
|
||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||
mlflow.set_tags(
|
||||
{
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": "sagemaker",
|
||||
"qc_cli.command": "train start",
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
}
|
||||
)
|
||||
return run_id
|
||||
|
||||
def finalize_training_run(
|
||||
@@ -125,7 +147,6 @@ class MlflowTracker:
|
||||
if not run_id:
|
||||
return FinalizeResult()
|
||||
|
||||
warnings: list[str] = []
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
self._log_params(
|
||||
{
|
||||
@@ -137,22 +158,14 @@ class MlflowTracker:
|
||||
}
|
||||
)
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
|
||||
warnings.extend(
|
||||
self._log_training_metrics(
|
||||
training_job_status.model_artifacts,
|
||||
region=region,
|
||||
profile=profile,
|
||||
)
|
||||
)
|
||||
mlflow.set_tag("qc_cli.command", command)
|
||||
|
||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||
return FinalizeResult(warnings=tuple(warnings))
|
||||
return FinalizeResult()
|
||||
|
||||
if not self.register_trained_models:
|
||||
return FinalizeResult(warnings=tuple(warnings))
|
||||
return FinalizeResult()
|
||||
|
||||
client = MlflowClient()
|
||||
self._ensure_registered_model(client, self.registered_model_name)
|
||||
@@ -171,7 +184,61 @@ class MlflowTracker:
|
||||
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
||||
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings))
|
||||
return FinalizeResult(registered_model_version=version_number)
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str:
|
||||
client = MlflowClient()
|
||||
experiment = client.get_experiment_by_name(self.experiment_name)
|
||||
if experiment is None:
|
||||
experiment_id = mlflow.create_experiment(self.experiment_name)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
|
||||
for run in client.search_runs([experiment_id], max_results=1000):
|
||||
if run.data.tags.get("sagemaker.job_name") == job_name:
|
||||
return str(run.info.run_id)
|
||||
|
||||
run = client.create_run(
|
||||
experiment_id,
|
||||
run_name=job_name,
|
||||
tags={
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": "sagemaker",
|
||||
"qc_cli.command": "mlflow upload-metrics",
|
||||
"sagemaker.job_name": job_name,
|
||||
},
|
||||
)
|
||||
return str(run.info.run_id)
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> None:
|
||||
if not training_job_status.model_artifacts:
|
||||
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
|
||||
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
self._log_params(
|
||||
{
|
||||
"sagemaker.training_status": training_job_status.status,
|
||||
"sagemaker.created_at": training_job_status.created,
|
||||
"sagemaker.modified_at": training_job_status.modified,
|
||||
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||
}
|
||||
)
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
self._log_training_metrics(
|
||||
training_job_status.model_artifacts,
|
||||
region=region,
|
||||
profile=profile,
|
||||
)
|
||||
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
|
||||
|
||||
def _log_params(self, params: dict[str, Any]) -> None:
|
||||
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||
@@ -188,28 +255,24 @@ class MlflowTracker:
|
||||
if metrics:
|
||||
mlflow.log_metrics(metrics)
|
||||
|
||||
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]:
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
||||
archive_path = s3.download_file(
|
||||
region,
|
||||
profile,
|
||||
model_artifacts,
|
||||
os.path.join(temp_dir, "model.tar.gz"),
|
||||
)
|
||||
metrics_data = read_training_metrics_from_tar(archive_path)
|
||||
if metrics_data is None:
|
||||
return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."]
|
||||
metrics = parse_training_metrics(metrics_data)
|
||||
for metric_step in metrics.steps:
|
||||
if metric_step.metrics:
|
||||
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
||||
if metrics.summary:
|
||||
mlflow.log_metrics(metrics.summary)
|
||||
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
||||
except Exception as exc:
|
||||
return [f"Could not import training metrics: {exc}"]
|
||||
return []
|
||||
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None:
|
||||
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
||||
archive_path = s3.download_file(
|
||||
region,
|
||||
profile,
|
||||
model_artifacts,
|
||||
os.path.join(temp_dir, "model.tar.gz"),
|
||||
)
|
||||
metrics_data = read_training_metrics_from_tar(archive_path)
|
||||
if metrics_data is None:
|
||||
raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.")
|
||||
metrics = parse_training_metrics(metrics_data)
|
||||
for metric_step in metrics.steps:
|
||||
if metric_step.metrics:
|
||||
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
||||
if metrics.summary:
|
||||
mlflow.log_metrics(metrics.summary)
|
||||
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
||||
|
||||
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||
try:
|
||||
|
||||
38
src/tracking/upload.py
Normal file
38
src/tracking/upload.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from src import state as state_ops
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.config import Config, MlflowMode
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
|
||||
def upload_training_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
force: bool = False,
|
||||
) -> str:
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
raise RuntimeError("MLflow is disabled in config.yaml.")
|
||||
|
||||
st = state_ops.store(config_path)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_metrics_uploaded") and not force:
|
||||
return str(job_state.get("mlflow_run_id") or "")
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if status.status != "Completed":
|
||||
raise RuntimeError(
|
||||
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
|
||||
)
|
||||
|
||||
tracker = MlflowTracker.from_config(cfg)
|
||||
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
|
||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||
tracker.upload_training_metrics(
|
||||
run_id=run_id,
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
)
|
||||
st.update_training_job(job_name, mlflow_metrics_uploaded=True)
|
||||
return run_id
|
||||
Reference in New Issue
Block a user