update
This commit is contained in:
@@ -102,8 +102,54 @@ def _finalize_terminal_job(
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_training_job(
    *,
    job_name: str,
    poll_interval: int,
    config_path: str,
    cfg: Config,
) -> None:
    """Block until the named training job reaches a terminal status.

    The job status is fetched through ``sm_ops.get_training_job_status`` in a
    loop, sleeping ``poll_interval`` seconds between fetches. A status line is
    printed only when the status differs from the previously printed one.

    Once the status is in ``_TERMINAL_STATUSES`` the full status is printed,
    ``_finalize_terminal_job`` is invoked with command ``"train start --wait"``,
    and a hint for opening MLflow is shown when the stored job record carries
    an ``mlflow_run_id`` and MLflow is not disabled in ``cfg``.

    Args:
        job_name: Name of the training job to poll.
        poll_interval: Seconds to sleep between status fetches.
        config_path: Config path used to open the state store and passed on
            to ``_finalize_terminal_job``.
        cfg: Loaded configuration; supplies the boto3 session and MLflow mode.

    Raises:
        typer.Exit: With code 130 when interrupted by ``KeyboardInterrupt``.
    """
    store = state_ops.store(config_path)
    last_printed: str | None = None
    try:
        while True:
            current = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            state = current.status

            # Report transitions only, so a long-running job does not spam the console.
            if state != last_printed:
                tint = _STATUS_COLOR.get(state, "white")
                CONSOLE.print(f"Job [cyan]{current.name}[/cyan]: [{tint}]{state}[/{tint}]")
                last_printed = state

            if state not in _TERMINAL_STATUSES:
                time.sleep(poll_interval)
                continue

            _print_training_status(current)
            _finalize_terminal_job(
                config_path=config_path,
                cfg=cfg,
                status=current,
                command="train start --wait",
            )
            # The record is read after finalization, matching the original call order.
            record = store.get_training_job(job_name)
            if record.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
            return
    except KeyboardInterrupt:
        # Ctrl+C only stops the local polling loop; 130 is passed to typer.Exit as the exit code.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
def start(
|
||||
wait: bool = typer.Option(False, "--wait", help="Wait for completion and finalize MLflow tracking"),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between status checks when --wait is used",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
@@ -156,7 +202,15 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
if wait:
|
||||
_wait_for_training_job(
|
||||
job_name=job_name,
|
||||
poll_interval=poll_interval,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
)
|
||||
else:
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -185,57 +239,6 @@ def status(
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
|
||||
|
||||
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    # NOTE(review): the docstring above is presumably surfaced by Typer as the
    # command help, so it is kept verbatim.
    cfg = load_cfg(config)
    store = state_ops.store(config)

    # With no explicit name, fall back to what the state store reports as the
    # last training job; give up if that is empty as well.
    job_name = job_name or store.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    last_reported: str | None = None
    try:
        # Poll until a terminal status shows up, printing a line on each change.
        while True:
            snapshot = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            if snapshot.status != last_reported:
                shade = _STATUS_COLOR.get(snapshot.status, "white")
                CONSOLE.print(f"Job [cyan]{snapshot.name}[/cyan]: [{shade}]{snapshot.status}[/{shade}]")
                last_reported = snapshot.status
            if snapshot.status in _TERMINAL_STATUSES:
                break
            time.sleep(poll_interval)

        # Terminal status reached: print details, finalize, then offer the MLflow hint.
        # This stays inside the try so an interrupt here is handled as before.
        _print_training_status(snapshot)
        _finalize_terminal_job(
            config_path=config,
            cfg=cfg,
            status=snapshot,
            command="train wait",
        )
        record = store.get_training_job(job_name)
        if record.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
            CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
    except KeyboardInterrupt:
        # Interrupting stops the local wait only; 130 is passed to typer.Exit as the exit code.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
|
||||
Reference in New Issue
Block a user