This commit is contained in:
2026-06-12 11:57:27 -04:00
parent 2d4d377051
commit 53e886a535
3 changed files with 61 additions and 58 deletions

View File

@@ -163,15 +163,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
```
qc-cli train start Submit a SageMaker training job
qc-cli train start --wait Submit, wait, and finalize MLflow tracking
qc-cli train status [job-name] Show job status; defaults to the last submitted job
qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking
qc-cli train list List recent training jobs
qc-cli train list --limit 3 Show a custom number of recent jobs
```
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
The expected output artifact is SageMaker's `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
@@ -219,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
Current behavior:
1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default.
2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default.
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
- `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source`

View File

@@ -153,10 +153,10 @@ Or pass the job name explicitly:
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
```
To wait for completion and automatically import metrics and register the model, run:
To submit the job, wait for completion, and automatically import metrics and register the model, run:
```bash
qc-cli train wait
qc-cli train start --wait
```
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.

View File

@@ -102,8 +102,54 @@ def _finalize_terminal_job(
)
def _wait_for_training_job(
    *,
    job_name: str,
    poll_interval: int,
    config_path: str,
    cfg: Config,
) -> None:
    """Block until a training job reaches a terminal status, then finalize it.

    The job status is fetched repeatedly, sleeping ``poll_interval`` seconds
    between fetches. A status line is printed only when the status differs
    from the previously observed one. Once the status is terminal, the full
    status is printed, the job is passed to ``_finalize_terminal_job``, and an
    "Open MLflow" hint is shown when the stored job state carries an
    ``mlflow_run_id`` and MLflow is not disabled in the configuration.

    Args:
        job_name: Name of the training job to poll.
        poll_interval: Seconds to sleep between status checks.
        config_path: Config path used to open the state store and passed on
            to finalization.
        cfg: Loaded configuration; supplies the boto3 session and MLflow mode.

    Raises:
        typer.Exit: With exit code 130 when the wait is interrupted from the
            keyboard.
    """
    store = state_ops.store(config_path)
    last_seen: str | None = None

    try:
        while True:
            snapshot = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            current = snapshot.status

            # Report only transitions, so a long-running job does not flood
            # the console with identical lines.
            if current != last_seen:
                color = _STATUS_COLOR.get(current, "white")
                message = (
                    f"Job [cyan]{snapshot.name}[/cyan]: "
                    f"[{color}]{current}[/{color}]"
                )
                CONSOLE.print(message)
                last_seen = current

            # Not finished yet: sleep and poll again.
            if current not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            _print_training_status(snapshot)
            _finalize_terminal_job(
                config_path=config_path,
                cfg=cfg,
                status=snapshot,
                command="train start --wait",
            )

            # Read the job state after finalization before deciding whether
            # to show the MLflow hint.
            job_state = store.get_training_job(job_name)
            has_run = job_state.get("mlflow_run_id")
            if has_run and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        # Interrupting the local command only stops the polling loop.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command()
def start(config: str = CONFIG_OPT) -> None:
def start(
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval",
min=1,
help="Seconds between status checks when --wait is used",
),
config: str = CONFIG_OPT,
) -> None:
"""Submit a SageMaker training job."""
cfg = load_cfg(config)
@@ -156,6 +202,14 @@ def start(config: str = CONFIG_OPT) -> None:
if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
if wait:
_wait_for_training_job(
job_name=job_name,
poll_interval=poll_interval,
config_path=config,
cfg=cfg,
)
else:
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@@ -185,57 +239,6 @@ def status(
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    cfg = load_cfg(config)
    store = state_ops.store(config)

    # Fall back to the last submitted job recorded in local state when no
    # name was given on the command line.
    job_name = job_name or store.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    last_seen: str | None = None
    try:
        while True:
            snapshot = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            current = snapshot.status

            # Print a status line only when the status has changed since the
            # previous poll.
            if current != last_seen:
                color = _STATUS_COLOR.get(current, "white")
                message = (
                    f"Job [cyan]{snapshot.name}[/cyan]: "
                    f"[{color}]{current}[/{color}]"
                )
                CONSOLE.print(message)
                last_seen = current

            # Still in progress: sleep, then check again.
            if current not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            _print_training_status(snapshot)
            _finalize_terminal_job(
                config_path=config,
                cfg=cfg,
                status=snapshot,
                command="train wait",
            )

            # Look up the job state after finalization to decide whether the
            # MLflow hint applies.
            job_state = store.get_training_job(job_name)
            has_run = job_state.get("mlflow_run_id")
            if has_run and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        # Ctrl+C ends only the local wait; exit with code 130.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),