From 53e886a535822082b979351640a2c08c5b01cf11 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 12 Jun 2026 11:57:27 -0400 Subject: [PATCH] update --- README.md | 6 +- examples/meter-detection/README.md | 4 +- src/commands/train.py | 109 +++++++++++++++-------------- 3 files changed, 61 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 770bedd..1901afc 100644 --- a/README.md +++ b/README.md @@ -163,15 +163,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de ``` qc-cli train start Submit a SageMaker training job +qc-cli train start --wait Submit, wait, and finalize MLflow tracking qc-cli train status [job-name] Show job status; defaults to the last submitted job -qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking qc-cli train list List recent training jobs qc-cli train list --limit 3 Show a custom number of recent jobs ``` `train start` uses `s3:////` as the training channel and writes outputs under `s3:////`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container. -`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. +`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval ` to choose another positive interval. Stopping the local command does not stop the SageMaker job. The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`. @@ -219,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas Current behavior: 1. `qc-cli train start` submits a SageMaker training job. -2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default. +2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default. 3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with: - `qc_cli.stage=experiment` - `qc_cli.artifact_kind=trained_source` diff --git a/examples/meter-detection/README.md b/examples/meter-detection/README.md index a85a3a5..e6d5c27 100644 --- a/examples/meter-detection/README.md +++ b/examples/meter-detection/README.md @@ -153,10 +153,10 @@ Or pass the job name explicitly: qc-cli train status qc-cli-YYYYMMDD-HHMMSS ``` -To wait for completion and automatically import metrics and register the model, run: +To submit the job, wait for completion, and automatically import metrics and register the model, run: ```bash -qc-cli train wait +qc-cli train start --wait ``` The default polling interval is 30 seconds. It can be changed with `--poll-interval `. diff --git a/src/commands/train.py b/src/commands/train.py index 5958514..3c38927 100644 --- a/src/commands/train.py +++ b/src/commands/train.py @@ -102,8 +102,54 @@ def _finalize_terminal_job( ) +def _wait_for_training_job( + *, + job_name: str, + poll_interval: int, + config_path: str, + cfg: Config, +) -> None: + st = state_ops.store(config_path) + previous_status: str | None = None + try: + while True: + training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) + if training_status.status != previous_status: + color = _STATUS_COLOR.get(training_status.status, "white") + CONSOLE.print( + f"Job [cyan]{training_status.name}[/cyan]: " + f"[{color}]{training_status.status}[/{color}]" + ) + previous_status = training_status.status + if training_status.status in _TERMINAL_STATUSES: + _print_training_status(training_status) + _finalize_terminal_job( + config_path=config_path, + cfg=cfg, + status=training_status, + command="train start --wait", + ) + job_state = st.get_training_job(job_name) + if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: + CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") + return + time.sleep(poll_interval) + except KeyboardInterrupt: + CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") + raise typer.Exit(130) + + @app.command() -def start(config: str = CONFIG_OPT) -> None: +def start( + wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"), + poll_interval: int = typer.Option( + DEFAULT_POLL_INTERVAL_SECONDS, + "--poll-interval", + min=1, + help="Seconds between status checks when --wait is used", + ), + config: str = CONFIG_OPT, +) -> None: """Submit a SageMaker training job.""" cfg = load_cfg(config) @@ -156,7 +202,15 @@ def start(config: str = CONFIG_OPT) -> None: if run_id: CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") + if wait: + _wait_for_training_job( + job_name=job_name, + poll_interval=poll_interval, + config_path=config, + cfg=cfg, + ) + else: + CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") @app.command() @@ -185,57 +239,6 @@ def status( CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") -@app.command() -def wait( - job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"), - poll_interval: int = typer.Option( - DEFAULT_POLL_INTERVAL_SECONDS, - "--poll-interval", - min=1, - help="Seconds between SageMaker status checks", - ), - config: str = CONFIG_OPT, -) -> None: - """Wait for a training job and finalize its MLflow run.""" - cfg = load_cfg(config) - st = state_ops.store(config) - if not job_name: - job_name = st.get_last_training_job() - if not job_name: - CONSOLE.print( - "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]" - ) - raise typer.Exit(1) - - previous_status: str | None = None - try: - while True: - training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name) - if training_status.status != previous_status: - color = _STATUS_COLOR.get(training_status.status, "white") - CONSOLE.print( - f"Job [cyan]{training_status.name}[/cyan]: " - f"[{color}]{training_status.status}[/{color}]" - ) - previous_status = training_status.status - if training_status.status in _TERMINAL_STATUSES: - _print_training_status(training_status) - _finalize_terminal_job( - config_path=config, - cfg=cfg, - status=training_status, - command="train wait", - ) - job_state = st.get_training_job(job_name) - if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled: - CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") - return - time.sleep(poll_interval) - except KeyboardInterrupt: - CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]") - raise typer.Exit(130) - - @app.command(name="list") def list_jobs( limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),