another update

This commit is contained in:
2026-06-12 12:17:02 -04:00
parent 53e886a535
commit 5211d0af14
6 changed files with 278 additions and 89 deletions

View File

@@ -2,8 +2,11 @@ import webbrowser
import typer
from src import state as state_ops
from src.aws import mlflow as aws_mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import MlflowMode
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage MLflow tracking server access")
@@ -39,3 +42,43 @@ def open_mlflow(config: str = CONFIG_OPT) -> None:
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
else:
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
@app.command(name="upload-metrics")
def upload_metrics(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a completed training job's metric history to MLflow."""
cfg = load_cfg(config)
if cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
st = state_ops.store(config)
if not job_name:
job_name = st.get_last_training_job()
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
return
try:
run_id = upload_training_metrics(
job_name=job_name,
config_path=config,
cfg=cfg,
force=force,
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")

View File

@@ -12,6 +12,7 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage SageMaker training jobs")
@@ -102,7 +103,7 @@ def _finalize_terminal_job(
)
def _wait_for_training_job(
def _wait_and_upload_metrics(
*,
job_name: str,
poll_interval: int,
@@ -123,12 +124,21 @@ def _wait_for_training_job(
previous_status = training_status.status
if training_status.status in _TERMINAL_STATUSES:
_print_training_status(training_status)
_finalize_terminal_job(
config_path=config_path,
cfg=cfg,
status=training_status,
command="train start --wait",
)
if training_status.status != "Completed":
raise typer.Exit(1)
try:
run_id = upload_training_metrics(
job_name=job_name,
config_path=config_path,
cfg=cfg,
)
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics to MLflow run [cyan]{run_id}[/cyan].")
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
CONSOLE.print(
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
raise typer.Exit(1)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@@ -141,18 +151,26 @@ def _wait_for_training_job(
@app.command()
def start(
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
upload_metrics: bool = typer.Option(
False,
"--upload-metrics",
help="Wait for completion, then upload training metrics to MLflow",
),
poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval",
min=1,
help="Seconds between status checks when --wait is used",
help="Seconds between status checks when --upload-metrics is used",
),
config: str = CONFIG_OPT,
) -> None:
"""Submit a SageMaker training job."""
cfg = load_cfg(config)
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
raise typer.Exit(1)
if not cfg.sagemaker.training.image_uri:
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
CONSOLE.print(
@@ -189,12 +207,20 @@ def start(
st = state_ops.store(config)
st.set_last_training_job(job_name)
run_id = tracker.start_training_run(
training_job,
region=cfg.aws.region,
profile=cfg.aws.profile,
role_arn=role_arn,
)
try:
run_id = tracker.start_training_run(
training_job,
region=cfg.aws.region,
profile=cfg.aws.profile,
role_arn=role_arn,
)
except Exception as e:
run_id = None
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
CONSOLE.print(
"The SageMaker job is still running. Upload metrics after completion with "
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
if run_id:
st.update_training_job(job_name, mlflow_run_id=run_id)
@@ -202,8 +228,8 @@ def start(
if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
if wait:
_wait_for_training_job(
if upload_metrics:
_wait_and_upload_metrics(
job_name=job_name,
poll_interval=poll_interval,
config_path=config,

View File

@@ -31,6 +31,17 @@ class Tracker(Protocol):
command: str,
) -> FinalizeResult: ...
def ensure_training_run(self, job_name: str) -> str: ...
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None: ...
@dataclass(frozen=True)
class NoopTracker:
@@ -48,6 +59,19 @@ class NoopTracker:
) -> FinalizeResult:
return FinalizeResult()
def ensure_training_run(self, job_name: str) -> str:
raise RuntimeError("MLflow is disabled.")
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None:
raise RuntimeError("MLflow is disabled.")
@dataclass(frozen=True)
class MlflowTracker:
@@ -73,7 +97,6 @@ class MlflowTracker:
tracking_server_name,
)
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)
return cls(
tracking_uri=tracking_uri,
@@ -83,34 +106,33 @@ class MlflowTracker:
)
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id)
params = {
"aws.region": region,
"aws.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start",
mlflow.set_experiment(self.experiment_name)
with mlflow.start_run(run_name=training_job.job_name) as run:
run_id = str(run.info.run_id)
params = {
"aws.region": region,
"aws.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
)
mlflow.end_run()
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name,
}
)
return run_id
def finalize_training_run(
@@ -125,7 +147,6 @@ class MlflowTracker:
if not run_id:
return FinalizeResult()
warnings: list[str] = []
with mlflow.start_run(run_id=run_id):
self._log_params(
{
@@ -137,22 +158,14 @@ class MlflowTracker:
}
)
self._log_final_metrics(training_job_status.raw)
if training_job_status.status == "Completed" and training_job_status.model_artifacts:
warnings.extend(
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
)
mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return FinalizeResult(warnings=tuple(warnings))
return FinalizeResult()
if not self.register_trained_models:
return FinalizeResult(warnings=tuple(warnings))
return FinalizeResult()
client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name)
@@ -171,7 +184,61 @@ class MlflowTracker:
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number)
return FinalizeResult(registered_model_version=version_number, warnings=tuple(warnings))
return FinalizeResult(registered_model_version=version_number)
def ensure_training_run(self, job_name: str) -> str:
client = MlflowClient()
experiment = client.get_experiment_by_name(self.experiment_name)
if experiment is None:
experiment_id = mlflow.create_experiment(self.experiment_name)
else:
experiment_id = experiment.experiment_id
for run in client.search_runs([experiment_id], max_results=1000):
if run.data.tags.get("sagemaker.job_name") == job_name:
return str(run.info.run_id)
run = client.create_run(
experiment_id,
run_name=job_name,
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.command": "mlflow upload-metrics",
"sagemaker.job_name": job_name,
},
)
return str(run.info.run_id)
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> None:
if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
with mlflow.start_run(run_id=run_id):
self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw)
self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
@@ -188,28 +255,24 @@ class MlflowTracker:
if metrics:
mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> list[str]:
try:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file(
region,
profile,
model_artifacts,
os.path.join(temp_dir, "model.tar.gz"),
)
metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None:
return [f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact."]
metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps:
if metric_step.metrics:
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
if metrics.summary:
mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
except Exception as exc:
return [f"Could not import training metrics: {exc}"]
return []
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> None:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file(
region,
profile,
model_artifacts,
os.path.join(temp_dir, "model.tar.gz"),
)
metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None:
raise ValueError(f"No {METRICS_ARTIFACT_NAME} found in the SageMaker model artifact.")
metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps:
if metric_step.metrics:
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
if metrics.summary:
mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try:

38
src/tracking/upload.py Normal file
View File

@@ -0,0 +1,38 @@
from src import state as state_ops
from src.aws import sagemaker as sm_ops
from src.config import Config, MlflowMode
from src.tracking.mlflow import MlflowTracker
def upload_training_metrics(
*,
job_name: str,
config_path: str,
cfg: Config,
force: bool = False,
) -> str:
if cfg.mlflow.mode is MlflowMode.disabled:
raise RuntimeError("MLflow is disabled in config.yaml.")
st = state_ops.store(config_path)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_metrics_uploaded") and not force:
return str(job_state.get("mlflow_run_id") or "")
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if status.status != "Completed":
raise RuntimeError(
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
)
tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id)
tracker.upload_training_metrics(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
)
st.update_training_job(job_name, mlflow_metrics_uploaded=True)
return run_id