another update

This commit is contained in:
2026-06-12 12:17:02 -04:00
parent 53e886a535
commit 5211d0af14
6 changed files with 278 additions and 89 deletions

View File

@@ -128,9 +128,14 @@ qc-cli init --force Overwrite an existing config file
### `mlflow` ### `mlflow`
``` ```
qc-cli mlflow open Open a presigned MLflow UI URL in a browser qc-cli mlflow open Open a presigned MLflow UI URL
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
``` ```
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run,
imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`.
Use `--force` to upload the metrics again.
### `infra` ### `infra`
``` ```
@@ -163,7 +168,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
``` ```
qc-cli train start Submit a SageMaker training job qc-cli train start Submit a SageMaker training job
qc-cli train start --wait Submit, wait, and finalize MLflow tracking qc-cli train start --upload-metrics Submit, wait, and upload metrics
qc-cli train status [job-name] Show job status; defaults to the last submitted job qc-cli train status [job-name] Show job status; defaults to the last submitted job
qc-cli train list List recent training jobs qc-cli train list List recent training jobs
qc-cli train list --limit 3 Show a custom number of recent jobs qc-cli train list --limit 3 Show a custom number of recent jobs
@@ -171,7 +176,7 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. `train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job. `train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
@@ -219,15 +224,19 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
Current behavior: Current behavior:
1. `qc-cli train start` submits a SageMaker training job. 1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default. 2. `qc-cli train status` finalizes the MLflow run and registers completed model artifacts.
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: 3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
5. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
- `qc_cli.stage=experiment` - `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source` - `qc_cli.artifact_kind=trained_source`
- `qc_cli.source=sagemaker` - `qc_cli.source=sagemaker`
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version. 6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model. 7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. During finalization, the CLI logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact: Training scripts can include a `training_metrics.json` file in the SageMaker model directory. The explicit metrics
upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores
the JSON as a run artifact:
```json ```json
{ {
@@ -239,7 +248,9 @@ Training scripts can include a `training_metrics.json` file in the SageMaker mod
} }
``` ```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. Missing or malformed metrics produce a warning but do not block model registration. Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and
strictly increasing. A missing or malformed metrics artifact fails the upload command without affecting the trained
model or model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact. Future release aliases such as `v1` or `production` can point at a selected deployable artifact.

View File

@@ -156,11 +156,17 @@ qc-cli train status qc-cli-YYYYMMDD-HHMMSS
To submit the job, wait for completion, and automatically import metrics and register the model, run: To submit the job, wait for completion, and automatically import metrics and register the model, run:
```bash ```bash
qc-cli train start --wait qc-cli train start --upload-metrics
``` ```
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`. The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
The metrics can be also submitted using:
```bash
qc-cli mlflow upload-metrics
```
## SageMaker Outputs ## SageMaker Outputs
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`. When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
@@ -176,7 +182,9 @@ training_metrics.json
The archive is stored under the configured `s3.model_prefix`. The archive is stored under the configured `s3.model_prefix`.
During MLflow finalization, `training_metrics.json` provides per-epoch training and validation losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall are more meaningful than classification accuracy when assessing model quality. The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
are more meaningful than classification accuracy when assessing model quality.
## 6. Configure Qualcomm AI Hub ## 6. Configure Qualcomm AI Hub

View File

@@ -2,8 +2,11 @@ import webbrowser
import typer import typer
from src import state as state_ops
from src.aws import mlflow as aws_mlflow from src.aws import mlflow as aws_mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import MlflowMode
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage MLflow tracking server access") app = typer.Typer(help="Manage MLflow tracking server access")
@@ -39,3 +42,43 @@ def open_mlflow(config: str = CONFIG_OPT) -> None:
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.") CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
else: else:
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]") CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
@app.command(name="upload-metrics")
def upload_metrics(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a completed training job's metric history to MLflow."""
cfg = load_cfg(config)
if cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
st = state_ops.store(config)
if not job_name:
job_name = st.get_last_training_job()
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
return
try:
run_id = upload_training_metrics(
job_name=job_name,
config_path=config,
cfg=cfg,
force=force,
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")

View File

@@ -12,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
from src.infra.state import read_infra_state from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker from src.tracking.mlflow import MlflowTracker
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage SageMaker training jobs") app = typer.Typer(help="Manage SageMaker training jobs")
@@ -102,7 +103,7 @@ def _finalize_terminal_job(
) )
def _wait_for_training_job( def _wait_and_upload_metrics(
*, *,
job_name: str, job_name: str,
poll_interval: int, poll_interval: int,
@@ -123,12 +124,21 @@ def _wait_for_training_job(
previous_status = training_status.status previous_status = training_status.status
if training_status.status in _TERMINAL_STATUSES: if training_status.status in _TERMINAL_STATUSES:
_print_training_status(training_status) _print_training_status(training_status)
_finalize_terminal_job( if training_status.status != "Completed":
config_path=config_path, raise typer.Exit(1)
cfg=cfg, try:
status=training_status, run_id = upload_training_metrics(
command="train start --wait", job_name=job_name,
) config_path=config_path,
cfg=cfg,
)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].")
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
CONSOLE.print(
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
raise typer.Exit(1)
job_state = st.get_training_job(job_name) job_state = st.get_training_job(job_name)
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@@ -141,18 +151,26 @@ def _wait_for_training_job(
@app.command() @app.command()
def start( def start(
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"), upload_metrics: bool = typer.Option(
False,
"--upload-metrics",
help="Wait for completion, then upload training metrics to MLflow",
),
poll_interval: int = typer.Option( poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS, DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval", "--poll-interval",
min=1, min=1,
help="Seconds between status checks when --wait is used", help="Seconds between status checks when --upload-metrics is used",
), ),
config: str = CONFIG_OPT, config: str = CONFIG_OPT,
) -> None: ) -> None:
"""Submit a SageMaker training job.""" """Submit a SageMaker training job."""
cfg = load_cfg(config) cfg = load_cfg(config)
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
raise typer.Exit(1)
if not cfg.sagemaker.training.image_uri: if not cfg.sagemaker.training.image_uri:
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]") CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
CONSOLE.print( CONSOLE.print(
@@ -189,12 +207,20 @@ def start(
st = state_ops.store(config) st = state_ops.store(config)
st.set_last_training_job(job_name) st.set_last_training_job(job_name)
run_id = tracker.start_training_run( try:
training_job, run_id = tracker.start_training_run(
region=cfg.aws.region, training_job,
profile=cfg.aws.profile, region=cfg.aws.region,
role_arn=role_arn, profile=cfg.aws.profile,
) role_arn=role_arn,
)
except Exception as e:
run_id = None
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
CONSOLE.print(
"The SageMaker job is still running. Upload metrics after completion with "
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
if run_id: if run_id:
st.update_training_job(job_name, mlflow_run_id=run_id) st.update_training_job(job_name, mlflow_run_id=run_id)
@@ -202,8 +228,8 @@ def start(
if run_id: if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
if wait: if upload_metrics:
_wait_for_training_job( _wait_and_upload_metrics(
job_name=job_name, job_name=job_name,
poll_interval=poll_interval, poll_interval=poll_interval,
config_path=config, config_path=config,

View File

@@ -31,6 +31,17 @@ class Tracker(Protocol):
command: str, command: str,
) -> FinalizeResult: ... ) -> FinalizeResult: ...
def ensure_training_run(self, job_name: str) -> str: ...
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class NoopTracker: class NoopTracker:
@@ -48,6 +59,19 @@ class NoopTracker:
) -> FinalizeResult: ) -> FinalizeResult:
return FinalizeResult() return FinalizeResult()
def ensure_training_run(self, job_name: str) -> str:
raise RuntimeError("MLflow is disabled.")
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None:
raise RuntimeError("MLflow is disabled.")
@dataclass(frozen=True) @dataclass(frozen=True)
class MlflowTracker: class MlflowTracker:
@@ -73,7 +97,6 @@ class MlflowTracker:
tracking_server_name, tracking_server_name,
) )
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)
return cls( return cls(
tracking_uri=tracking_uri, tracking_uri=tracking_uri,
@@ -83,34 +106,33 @@ class MlflowTracker:
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
run = mlflow.start_run(run_name=training_job.job_name) mlflow.set_experiment(self.experiment_name)
run_id = str(run.info.run_id) with mlflow.start_run(run_name=training_job.job_name) as run:
run_id = str(run.info.run_id)
params = { params = {
"aws.region": region, "aws.region": region,
"aws.profile": profile, "aws.profile": profile,
"sagemaker.role_arn": role_arn, "sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name, "sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
} }
) self._log_params(params)
mlflow.end_run() self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name,
}
)
return run_id return run_id
def finalize_training_run( def finalize_training_run(
@@ -125,7 +147,6 @@ class MlflowTracker:
if not run_id: if not run_id:
return FinalizeResult() return FinalizeResult()
warnings: list[str] = []
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params( self._log_params(
{ {
@@ -137,22 +158,14 @@ class MlflowTracker:
} }
) )
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
warnings.extend(
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
)
mlflow.set_tag("qc_cli.command", command) mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts: if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return FinalizeResult(warnings=tuple(warnings)) return FinalizeResult()
if not self.register_trained_models: if not self.register_trained_models:
return FinalizeResult(warnings=tuple(warnings)) return FinalizeResult()
client = MlflowClient() client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name) self._ensure_registered_model(client, self.registered_model_name)
@@ -171,7 +184,61 @@ class MlflowTracker:
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number) mlflow.set_tag("qc_cli.registered_model_version", version_number)
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings)) return FinalizeResult(registered_model_version=version_number)
def ensure_training_run(self, job_name: str) -> str:
client = MlflowClient()
experiment = client.get_experiment_by_name(self.experiment_name)
if experiment is None:
experiment_id = mlflow.create_experiment(self.experiment_name)
else:
experiment_id = experiment.experiment_id
for run in client.search_runs([experiment_id], max_results=1000):
if run.data.tags.get("sagemaker.job_name") == job_name:
return str(run.info.run_id)
run = client.create_run(
experiment_id,
run_name=job_name,
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "mlflow upload-metrics",
"sagemaker.job_name": job_name,
},
)
return str(run.info.run_id)
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None:
if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
with mlflow.start_run(run_id=run_id):
self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw)
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
def _log_params(self, params: dict[str, Any]) -> None: def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None} cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -188,28 +255,24 @@ class MlflowTracker:
if metrics: if metrics:
mlflow.log_metrics(metrics) mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]: def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None:
try: with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir: archive_path = s3.download_file(
archive_path = s3.download_file( region,
region, profile,
profile, model_artifacts,
model_artifacts, os.path.join(temp_dir, "model.tar.gz"),
os.path.join(temp_dir, "model.tar.gz"), )
) metrics_data = read_training_metrics_from_tar(archive_path)
metrics_data = read_training_metrics_from_tar(archive_path) if metrics_data is None:
if metrics_data is None: raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.")
return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."] metrics = parse_training_metrics(metrics_data)
metrics = parse_training_metrics(metrics_data) for metric_step in metrics.steps:
for metric_step in metrics.steps: if metric_step.metrics:
if metric_step.metrics: mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
mlflow.log_metrics(metric_step.metrics, step=metric_step.step) if metrics.summary:
if metrics.summary: mlflow.log_metrics(metrics.summary)
mlflow.log_metrics(metrics.summary) mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
except Exception as exc:
return [f"Could not import training metrics: {exc}"]
return []
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try: try:

38
src/tracking/upload.py Normal file
View File

@@ -0,0 +1,38 @@
from src import state as state_ops
from src.aws import sagemaker as sm_ops
from src.config import Config, MlflowMode
from src.tracking.mlflow import MlflowTracker
def upload_training_metrics(
*,
job_name: str,
config_path: str,
cfg: Config,
force: bool = False,
) -> str:
if cfg.mlflow.mode is MlflowMode.disabled:
raise RuntimeError("MLflow is disabled in config.yaml.")
st = state_ops.store(config_path)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_metrics_uploaded") and not force:
return str(job_state.get("mlflow_run_id") or "")
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if status.status != "Completed":
raise RuntimeError(
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
)
tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id)
tracker.upload_training_metrics(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
)
st.update_training_job(job_name, mlflow_metrics_uploaded=True)
return run_id