WIP
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,6 +22,8 @@ _STATUS_COLOR = {
|
||||
"Stopping": "yellow",
|
||||
"Stopped": "dim",
|
||||
}
|
||||
# Training job statuses after which the job will not change state again.
# A frozenset keeps this module-level constant from being mutated by accident.
_TERMINAL_STATUSES = frozenset({"Completed", "Failed", "Stopped"})
# Default number of seconds between status checks in the `wait` command.
DEFAULT_POLL_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _tracker(cfg):
|
||||
@@ -48,6 +51,57 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
||||
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||
|
||||
|
||||
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
    """Print a summary of a training job to the console.

    The job name and status are always printed. The creation time, model
    artifacts and failure reason are printed only when they are truthy.
    The status is styled with its entry in ``_STATUS_COLOR``, or "white"
    when the status has no entry there.
    """
    status_style = _STATUS_COLOR.get(status.status, "white")
    summary_lines = [
        f"Job: [cyan]{status.name}[/cyan]",
        f"Status: [{status_style}]{status.status}[/{status_style}]",
    ]

    # Optional details are appended in the order they are displayed.
    if status.created:
        summary_lines.append(f"Created: {status.created}")
    if status.model_artifacts:
        summary_lines.append(f"Artifacts: {status.model_artifacts}")
    if status.failure_reason:
        summary_lines.append(f"[red]Failure: {status.failure_reason}[/red]")

    for line in summary_lines:
        CONSOLE.print(line)
|
||||
|
||||
|
||||
def _finalize_terminal_job(
    *,
    config_path: str,
    cfg: Config,
    status: sm_ops.TrainingJobStatus,
    command: str,
) -> None:
    """Finalize the MLflow run recorded for a training job that has ended.

    Returns without doing anything when the job status is not in
    ``_TERMINAL_STATUSES``, when the state store holds no ``mlflow_run_id``
    for the job, or when ``mlflow_finalized_status`` is already recorded.
    Otherwise the tracker finalizes the run, the finalized status (and the
    registered model version, if any) is written to the state store, and
    warnings plus the model version are printed.

    Args:
        config_path: Path used to open the state store.
        cfg: Loaded configuration; supplies the tracker and the AWS
            region/profile passed to the tracker.
        status: Latest status of the training job.
        command: Name of the CLI command triggering finalization; passed
            through to the tracker.
    """
    job_is_terminal = status.status in _TERMINAL_STATUSES
    if not job_is_terminal:
        return

    store = state_ops.store(config_path)
    recorded = store.get_training_job(status.name)
    mlflow_run_id = recorded.get("mlflow_run_id")
    if not mlflow_run_id:
        return
    if recorded.get("mlflow_finalized_status"):
        # Already finalized by an earlier invocation.
        return

    outcome = _tracker(cfg).finalize_training_run(
        run_id=str(mlflow_run_id),
        training_job_status=status,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        command=command,
    )
    model_version = outcome.registered_model_version

    # Record the finalized status first so later invocations skip this job.
    state_changes = {"mlflow_finalized_status": status.status}
    if model_version:
        state_changes["registered_model_version"] = model_version
    store.update_training_job(status.name, **state_changes)

    for warning in outcome.warnings:
        CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")

    if not model_version:
        return
    store.set_latest_experiment_model_version(model_version)
    CONSOLE.print(
        f"MLflow model version: [cyan]{model_version}[/cyan] ([cyan]experiment-latest[/cyan])"
    )
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
@@ -123,37 +177,65 @@ def status(
|
||||
raise typer.Exit(1)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
_print_training_status(status)
|
||||
_finalize_terminal_job(config_path=config, cfg=cfg, status=status, command="train status")
|
||||
|
||||
job_state = st.get_training_job(job_name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
already_registered = job_state.get("registered_model_version")
|
||||
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||
tracker = _tracker(cfg)
|
||||
version = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if version:
|
||||
updates["registered_model_version"] = version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if version:
|
||||
st.set_latest_experiment_model_version(version)
|
||||
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
|
||||
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
|
||||
|
||||
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    # NOTE: the docstring above is kept to one line on purpose, since Typer
    # shows a command's docstring as its --help text.
    #
    # Flow: resolve the job name (argument, else the last job recorded in
    # state), poll its status every `poll_interval` seconds, print each status
    # change, and once the status is terminal print the full summary and
    # finalize the MLflow run. Exits with 1 when no job name can be resolved
    # and with 130 when interrupted with Ctrl-C.
    cfg = load_cfg(config)
    st = state_ops.store(config)
    if not job_name:
        job_name = st.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    previous_status: str | None = None
    # Set once a terminal status has been observed, so that a Ctrl-C during
    # finalization is not reported as "the job is still running".
    job_finished = False
    try:
        while True:
            training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
            if training_status.status != previous_status:
                # Only print when the status changes, to keep the output short.
                color = _STATUS_COLOR.get(training_status.status, "white")
                CONSOLE.print(
                    f"Job [cyan]{training_status.name}[/cyan]: "
                    f"[{color}]{training_status.status}[/{color}]"
                )
                previous_status = training_status.status
            if training_status.status in _TERMINAL_STATUSES:
                job_finished = True
                _print_training_status(training_status)
                _finalize_terminal_job(
                    config_path=config,
                    cfg=cfg,
                    status=training_status,
                    command="train wait",
                )
                job_state = st.get_training_job(job_name)
                if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                    CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
                return
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        if job_finished:
            CONSOLE.print(
                "[yellow]Interrupted during finalization. The SageMaker training job has already ended; "
                "MLflow finalization may be incomplete.[/yellow]"
            )
        else:
            CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        # 130 is the conventional exit code for termination by Ctrl-C; the
        # KeyboardInterrupt context adds nothing useful to the exit.
        raise typer.Exit(130) from None
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
|
||||
Reference in New Issue
Block a user