This commit is contained in:
2026-06-12 11:57:27 -04:00
parent 2d4d377051
commit 53e886a535
3 changed files with 61 additions and 58 deletions

View File

@@ -102,8 +102,54 @@ def _finalize_terminal_job(
)
def _wait_for_training_job(
    *,
    job_name: str,
    poll_interval: int,
    config_path: str,
    cfg: Config,
) -> None:
    """Block until the named training job reaches a terminal status.

    The status is fetched through ``sm_ops.get_training_job_status`` once per
    iteration, and a line is printed only when it differs from the previously
    seen status. Once the status is in ``_TERMINAL_STATUSES`` the details are
    printed, ``_finalize_terminal_job`` is invoked, and the function returns.

    Args:
        job_name: Name of the training job to poll.
        poll_interval: Seconds to sleep between status checks.
        config_path: Config path used to open the state store; also passed on
            to ``_finalize_terminal_job``.
        cfg: Loaded configuration; supplies the boto3 session and MLflow mode.

    Raises:
        typer.Exit: With code 130 when interrupted by ``KeyboardInterrupt``.
    """
    store = state_ops.store(config_path)
    last_seen: str | None = None
    try:
        while True:
            snapshot = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            current = snapshot.status

            # Announce transitions only, not every poll.
            if current != last_seen:
                color = _STATUS_COLOR.get(current, "white")
                CONSOLE.print(
                    f"Job [cyan]{snapshot.name}[/cyan]: "
                    f"[{color}]{current}[/{color}]"
                )
                last_seen = current

            # Still in flight: sleep and poll again.
            if current not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            _print_training_status(snapshot)
            _finalize_terminal_job(
                config_path=config_path,
                cfg=cfg,
                status=snapshot,
                command="train start --wait",
            )
            record = store.get_training_job(job_name)
            # Point at MLflow only when a run id was recorded and MLflow is enabled.
            if record.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        # NOTE(review): this handler also covers the finalization step above, where
        # the job is already terminal and the message below may be inaccurate — confirm.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command()
def start(config: str = CONFIG_OPT) -> None:
def start(
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval",
min=1,
help="Seconds between status checks when --wait is used",
),
config: str = CONFIG_OPT,
) -> None:
"""Submit a SageMaker training job."""
cfg = load_cfg(config)
@@ -156,7 +202,15 @@ def start(config: str = CONFIG_OPT) -> None:
if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
if wait:
_wait_for_training_job(
job_name=job_name,
poll_interval=poll_interval,
config_path=config,
cfg=cfg,
)
else:
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@app.command()
@@ -185,57 +239,6 @@ def status(
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    cfg = load_cfg(config)
    st = state_ops.store(config)

    # Fall back to the last job recorded in state when no name was given.
    job_name = job_name or st.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    seen_status: str | None = None
    try:
        while True:
            job = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            status_now = job.status
            changed = status_now != seen_status
            if changed:
                # Print a line only on a status transition.
                tint = _STATUS_COLOR.get(status_now, "white")
                CONSOLE.print(
                    f"Job [cyan]{job.name}[/cyan]: "
                    f"[{tint}]{status_now}[/{tint}]"
                )
                seen_status = status_now

            finished = status_now in _TERMINAL_STATUSES
            if finished:
                _print_training_status(job)
                _finalize_terminal_job(
                    config_path=config,
                    cfg=cfg,
                    status=job,
                    command="train wait",
                )
                entry = st.get_training_job(job_name)
                if entry.get("mlflow_run_id"):
                    # MLflow mode is consulted only when a run id is present.
                    if cfg.mlflow.mode is not MlflowMode.disabled:
                        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
                return

            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # Ctrl-C stops the local wait only; exit with code 130.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),