Files
qai-cli/src/commands/train.py
2026-06-12 11:42:26 -04:00

267 lines
9.2 KiB
Python

import time
from datetime import datetime
from pathlib import Path

import typer
from rich.markup import escape
from rich.table import Table

from src import state as state_ops
from src.aws import iam
from src.aws import sagemaker as sm_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker
# Typer sub-application; the commands in this module register on it via @app.command().
app = typer.Typer(help="Manage SageMaker training jobs")
# Rich style used to render each training job status. Callers look statuses up
# with a "white" fallback, so unlisted statuses still render.
_STATUS_COLOR = {
    "Completed": "green",
    "Failed": "red",
    "InProgress": "yellow",
    "Stopping": "yellow",
    "Stopped": "dim",
}

# Statuses after which a job no longer changes; polling stops and the MLflow
# run is finalized once one of these is seen.
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}

# Default for the `wait` command's --poll-interval option, in seconds.
DEFAULT_POLL_INTERVAL_SECONDS = 30
def _tracker(cfg: Config) -> MlflowTracker:
    """Build the MLflow tracker for ``cfg``, exiting the CLI if setup fails.

    Args:
        cfg: Loaded CLI configuration, passed to ``MlflowTracker.from_config``.

    Returns:
        Whatever ``MlflowTracker.from_config`` returns for this configuration.

    Raises:
        typer.Exit: With exit code 1, after printing the error, when
            ``MlflowTracker.from_config`` raises any ``Exception``.
    """
    try:
        return MlflowTracker.from_config(cfg)
    except Exception as e:
        # Escape the message so brackets inside it are printed literally
        # instead of being parsed as Rich markup tags.
        CONSOLE.print(f"[red]MLflow setup failed: {escape(str(e))}[/red]")
        # Chain the cause so the original failure stays attached to the exit.
        raise typer.Exit(1) from e
def _config_dir(config_path: str) -> str:
    """Return the parent directory of ``config_path`` as a string."""
    parent_dir = Path(config_path).parent
    return str(parent_dir)
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
    """Resolve the ARN of the IAM role used for SageMaker jobs.

    The ``SageMakerRoleArn`` entry under ``outputs`` in the infra state (read
    from the config file's directory) wins when present. Otherwise the role
    named by ``cfg.sagemaker.role_name`` is looked up through
    ``iam.get_role_arn``.

    Raises:
        RuntimeError: When neither source yields a role ARN.
    """
    infra_outputs = read_infra_state(_config_dir(config_path)).get("outputs", {})
    recorded_arn = infra_outputs.get("SageMakerRoleArn")
    if recorded_arn:
        return str(recorded_arn)

    role_name = cfg.sagemaker.role_name
    if not role_name:
        raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")

    looked_up_arn = iam.get_role_arn(cfg.aws.profile, role_name)
    if not looked_up_arn:
        raise RuntimeError(f"IAM role '{role_name}' not found. Run 'qc-cli infra setup' first.")
    return looked_up_arn
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
    """Print a training job's name and status, plus any optional details.

    The creation time, model artifacts location and failure reason are each
    printed only when the corresponding attribute on ``status`` is truthy.

    Args:
        status: Training job status to render on the console.
    """
    color = _STATUS_COLOR.get(status.status, "white")
    CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
    CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
    if status.created:
        CONSOLE.print(f"Created: {status.created}")
    if status.model_artifacts:
        CONSOLE.print(f"Artifacts: {status.model_artifacts}")
    if status.failure_reason:
        # The failure reason is free-form text; escape it so bracketed
        # fragments are shown verbatim instead of being parsed as Rich markup.
        reason = escape(str(status.failure_reason))
        CONSOLE.print(f"[red]Failure: {reason}[/red]")
def _finalize_terminal_job(
    *,
    config_path: str,
    cfg: Config,
    status: sm_ops.TrainingJobStatus,
    command: str,
) -> None:
    """Finalize the MLflow run tied to a training job that has finished.

    Returns without doing anything when ``status`` is not terminal, when the
    stored job state has no ``mlflow_run_id``, or when
    ``mlflow_finalized_status`` is already recorded, so a run that was
    finalized earlier is skipped.

    Args:
        config_path: Path of the config file, used to open the state store.
        cfg: Loaded configuration; supplies the tracker settings and the AWS
            region and profile forwarded to the tracker.
        status: Latest status of the training job.
        command: Name of the CLI command triggering finalization; forwarded
            to ``finalize_training_run``.

    Raises:
        typer.Exit: Raised by ``_tracker`` when MLflow setup fails.
    """
    if status.status not in _TERMINAL_STATUSES:
        return

    st = state_ops.store(config_path)
    job_state = st.get_training_job(status.name)
    run_id = job_state.get("mlflow_run_id")
    if not run_id or job_state.get("mlflow_finalized_status"):
        return

    tracker = _tracker(cfg)
    result = tracker.finalize_training_run(
        run_id=str(run_id),
        training_job_status=status,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        command=command,
    )

    # Persist the finalized marker first so later calls skip this job.
    updates = {"mlflow_finalized_status": status.status}
    if result.registered_model_version:
        updates["registered_model_version"] = result.registered_model_version
    st.update_training_job(status.name, **updates)

    for warning in result.warnings:
        # Warning text comes from the tracker; escape it so brackets are shown
        # literally rather than being parsed as Rich markup.
        CONSOLE.print(f"[yellow]MLflow metrics warning: {escape(str(warning))}[/yellow]")

    if result.registered_model_version:
        st.set_latest_experiment_model_version(result.registered_model_version)
        CONSOLE.print(
            f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
            "([cyan]experiment-latest[/cyan])"
        )
@app.command()
def start(config: str = CONFIG_OPT) -> None:
    """Submit a SageMaker training job."""
    # Typer shows the docstring above as the command's help text, so
    # implementation notes are kept in comments.
    cfg = load_cfg(config)

    # Fail fast: a training image must be configured.
    if not cfg.sagemaker.training.image_uri:
        CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
        CONSOLE.print(
            "Find pre-built images at: "
            "https://aws.github.io/deep-learning-containers/reference/available_images"
        )
        raise typer.Exit(1)

    try:
        role_arn = _sagemaker_role_arn(config, cfg)
    except RuntimeError as e:
        # Escape so brackets in the message are printed literally, and chain
        # the cause so the original error stays attached to the exit.
        CONSOLE.print(f"[red]{escape(str(e))}[/red]")
        raise typer.Exit(1) from e

    # Build the tracker before submitting, so an MLflow setup failure aborts
    # the command before a job has been launched.
    tracker = _tracker(cfg)

    # Job names are derived from the local submission time (second resolution).
    job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
    s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"

    CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
    training_job = sm_ops.TrainingJobRequest(
        role_arn=role_arn,
        image_uri=cfg.sagemaker.training.image_uri,
        instance_type=cfg.sagemaker.training.instance_type,
        instance_count=cfg.sagemaker.training.instance_count,
        s3_train_uri=s3_train_uri,
        s3_output_path=s3_output,
        job_name=job_name,
        hyperparameters=cfg.sagemaker.training.hyperparameters,
        entry_point=cfg.sagemaker.training.entry_point,
        source_dir=cfg.sagemaker.training.source_dir,
    )
    sm_ops.start_training_job(cfg.aws.boto3_session, training_job)

    # Remember the job so `status` and `wait` can default to it.
    st = state_ops.store(config)
    st.set_last_training_job(job_name)

    run_id = tracker.start_training_run(
        training_job,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        role_arn=role_arn,
    )
    if run_id:
        # Stored so the run can be finalized once the job reaches a terminal status.
        st.update_training_job(job_name, mlflow_run_id=run_id)

    CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
    if run_id:
        CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
    CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@app.command()
def status(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    config: str = CONFIG_OPT,
) -> None:
    """Show training job status."""
    # Typer shows the docstring above as the command's help text, so
    # implementation notes are kept in comments.
    cfg = load_cfg(config)
    store = state_ops.store(config)

    # Fall back to the most recently submitted job when no name was given.
    job_name = job_name or store.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
    _print_training_status(training_status)
    # Returns early unless the job is terminal and its MLflow run is not yet finalized.
    _finalize_terminal_job(config_path=config, cfg=cfg, status=training_status, command="train status")

    # Only point at MLflow when this job has a run recorded and tracking is enabled.
    tracked_run = store.get_training_job(job_name).get("mlflow_run_id")
    if tracked_run and cfg.mlflow.mode is not MlflowMode.disabled:
        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    # Typer shows the docstring above as the command's help text, so
    # implementation notes are kept in comments.
    cfg = load_cfg(config)
    st = state_ops.store(config)

    # Fall back to the most recently submitted job when no name was given.
    if not job_name:
        job_name = st.get_last_training_job()
    if not job_name:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)

    # Last status printed; also tells the interrupt handler how far we got.
    previous_status: str | None = None
    try:
        while True:
            training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)

            # Print a line only when the status changes, not on every poll.
            if training_status.status != previous_status:
                color = _STATUS_COLOR.get(training_status.status, "white")
                CONSOLE.print(
                    f"Job [cyan]{training_status.name}[/cyan]: "
                    f"[{color}]{training_status.status}[/{color}]"
                )
                previous_status = training_status.status

            if training_status.status in _TERMINAL_STATUSES:
                _print_training_status(training_status)
                _finalize_terminal_job(
                    config_path=config,
                    cfg=cfg,
                    status=training_status,
                    command="train wait",
                )
                job_state = st.get_training_job(job_name)
                if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
                    CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
                return

            time.sleep(poll_interval)
    except KeyboardInterrupt:
        if previous_status in _TERMINAL_STATUSES:
            # The interrupt arrived after the job had already finished (for
            # example during MLflow finalization), so saying it is "still
            # running" would be wrong. `train status` performs the same
            # finalization step for a terminal job.
            CONSOLE.print(
                "[yellow]Interrupted after the training job finished; "
                "MLflow finalization may be incomplete. "
                f"Run 'qc-cli train status {job_name}' to retry.[/yellow]"
            )
        else:
            CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        # 130 is the conventional exit code for termination by Ctrl-C; the
        # interrupt itself is not worth chaining onto the exit.
        raise typer.Exit(130) from None
@app.command(name="list")
def list_jobs(
    # min=1 rejects zero/negative limits at the CLI boundary, matching the
    # validation already applied to `wait --poll-interval`.
    limit: int = typer.Option(10, "--limit", "-n", min=1, help="Number of jobs to show"),
    config: str = CONFIG_OPT,
) -> None:
    """List recent training jobs."""
    # Typer shows the docstring above as the command's help text, so
    # implementation notes are kept in comments.
    cfg = load_cfg(config)
    jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
    if not jobs:
        CONSOLE.print("[yellow]No training jobs found.[/yellow]")
        return

    table = Table(title="Training Jobs")
    table.add_column("Name", style="cyan")
    table.add_column("Status")
    table.add_column("Created")
    for job in jobs:
        status_value = str(job["TrainingJobStatus"])
        # Unknown statuses fall back to plain white.
        color = _STATUS_COLOR.get(status_value, "white")
        table.add_row(
            str(job["TrainingJobName"]),
            f"[{color}]{status_value}[/{color}]",
            # CreationTime is optional in the job summary; show an empty cell if absent.
            str(job.get("CreationTime", "")),
        )
    CONSOLE.print(table)