update
This commit is contained in:
@@ -163,15 +163,15 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
|
||||
|
||||
```
|
||||
qc-cli train start Submit a SageMaker training job
|
||||
qc-cli train start --wait Submit, wait, and finalize MLflow tracking
|
||||
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||
qc-cli train wait [job-name] Wait for completion and finalize MLflow tracking
|
||||
qc-cli train list List recent training jobs
|
||||
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
```
|
||||
|
||||
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||
|
||||
`train wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
`train start --wait` checks SageMaker every 30 seconds by default. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
|
||||
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||
|
||||
@@ -219,7 +219,7 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
|
||||
Current behavior:
|
||||
|
||||
1. `qc-cli train start` submits a SageMaker training job.
|
||||
2. `qc-cli train status` or `qc-cli train wait` finalizes the MLflow run after the job reaches a terminal state. `train wait` blocks and polls every 30 seconds by default.
|
||||
2. `qc-cli train status` or `qc-cli train start --wait` finalizes the MLflow run after the job reaches a terminal state. `--wait` polls every 30 seconds by default.
|
||||
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||
- `qc_cli.stage=experiment`
|
||||
- `qc_cli.artifact_kind=trained_source`
|
||||
|
||||
@@ -153,10 +153,10 @@ Or pass the job name explicitly:
|
||||
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
||||
```
|
||||
|
||||
To wait for completion and automatically import metrics and register the model, run:
|
||||
To submit the job, wait for completion, and automatically import metrics and register the model, run:
|
||||
|
||||
```bash
|
||||
qc-cli train wait
|
||||
qc-cli train start --wait
|
||||
```
|
||||
|
||||
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
||||
|
||||
@@ -102,8 +102,54 @@ def _finalize_terminal_job(
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_training_job(
|
||||
*,
|
||||
job_name: str,
|
||||
poll_interval: int,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
) -> None:
|
||||
st = state_ops.store(config_path)
|
||||
previous_status: str | None = None
|
||||
try:
|
||||
while True:
|
||||
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if training_status.status != previous_status:
|
||||
color = _STATUS_COLOR.get(training_status.status, "white")
|
||||
CONSOLE.print(
|
||||
f"Job [cyan]{training_status.name}[/cyan]: "
|
||||
f"[{color}]{training_status.status}[/{color}]"
|
||||
)
|
||||
previous_status = training_status.status
|
||||
if training_status.status in _TERMINAL_STATUSES:
|
||||
_print_training_status(training_status)
|
||||
_finalize_terminal_job(
|
||||
config_path=config_path,
|
||||
cfg=cfg,
|
||||
status=training_status,
|
||||
command="train start --wait",
|
||||
)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
return
|
||||
time.sleep(poll_interval)
|
||||
except KeyboardInterrupt:
|
||||
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
||||
raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
def start(
|
||||
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between status checks when --wait is used",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
@@ -156,7 +202,15 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
if wait:
|
||||
_wait_for_training_job(
|
||||
job_name=job_name,
|
||||
poll_interval=poll_interval,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
)
|
||||
else:
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -185,57 +239,6 @@ def status(
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def wait(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between SageMaker status checks",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Wait for a training job and finalize its MLflow run."""
|
||||
cfg = load_cfg(config)
|
||||
st = state_ops.store(config)
|
||||
if not job_name:
|
||||
job_name = st.get_last_training_job()
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
previous_status: str | None = None
|
||||
try:
|
||||
while True:
|
||||
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if training_status.status != previous_status:
|
||||
color = _STATUS_COLOR.get(training_status.status, "white")
|
||||
CONSOLE.print(
|
||||
f"Job [cyan]{training_status.name}[/cyan]: "
|
||||
f"[{color}]{training_status.status}[/{color}]"
|
||||
)
|
||||
previous_status = training_status.status
|
||||
if training_status.status in _TERMINAL_STATUSES:
|
||||
_print_training_status(training_status)
|
||||
_finalize_terminal_job(
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
status=training_status,
|
||||
command="train wait",
|
||||
)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
return
|
||||
time.sleep(poll_interval)
|
||||
except KeyboardInterrupt:
|
||||
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
||||
raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
|
||||
Reference in New Issue
Block a user