"""Typer sub-commands for submitting, inspecting and waiting on SageMaker training jobs."""

import time
from datetime import datetime
from pathlib import Path

import typer
from rich.table import Table

from src import state as state_ops
from src.aws import iam
from src.aws import sagemaker as sm_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker

app = typer.Typer(help="Manage SageMaker training jobs")

# Rich colour used to render each SageMaker job status; statuses not listed
# here are rendered in "white" by the callers.
_STATUS_COLOR = {
    "Completed": "green",
    "Failed": "red",
    "InProgress": "yellow",
    "Stopping": "yellow",
    "Stopped": "dim",
}

# Statuses after which polling stops and the MLflow run gets finalized.
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}

DEFAULT_POLL_INTERVAL_SECONDS = 30


def _tracker(cfg: Config) -> MlflowTracker:
    """Build the MLflow tracker for *cfg*.

    Any failure is reported on the console and converted into a CLI exit
    with code 1; the original exception is kept as the cause.
    """
    try:
        return MlflowTracker.from_config(cfg)
    except Exception as e:
        CONSOLE.print(f"[red]MLflow setup failed: {e}[/red]")
        raise typer.Exit(1) from e


def _config_dir(config_path: str) -> str:
    """Return the directory that contains the config file."""
    return str(Path(config_path).parent)


def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
    """Resolve the ARN of the SageMaker execution role.

    The ``SageMakerRoleArn`` output recorded in the infra state is preferred.
    When it is missing and ``cfg.sagemaker.role_name`` is set, the role is
    looked up through IAM instead.

    Raises:
        RuntimeError: if neither source yields a role ARN.
    """
    state = read_infra_state(_config_dir(config_path))
    role_arn = state.get("outputs", {}).get("SageMakerRoleArn")
    if role_arn:
        return str(role_arn)

    if cfg.sagemaker.role_name:
        role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
        if role_arn:
            return role_arn
        raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")

    raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")


def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
    """Print the job name and status, plus any optional fields that are set."""
    color = _STATUS_COLOR.get(status.status, "white")
    CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
    CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
    if status.created:
        CONSOLE.print(f"Created: {status.created}")
    if status.model_artifacts:
        CONSOLE.print(f"Artifacts: {status.model_artifacts}")
    if status.failure_reason:
        CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")


def _resolve_job_name(st, job_name: str | None) -> str:
    """Return *job_name*, falling back to the last job recorded in *st*.

    *st* is the store object returned by ``state_ops.store``. When no name is
    given and the store has none either, an error is printed and the CLI
    exits with code 1.
    """
    resolved = job_name or st.get_last_training_job()
    if not resolved:
        CONSOLE.print(
            "[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
        )
        raise typer.Exit(1)
    return resolved


def _print_mlflow_hint(st, cfg: Config, job_name: str) -> None:
    """Point at ``qc-cli mlflow open`` when the job has an MLflow run and MLflow is enabled."""
    job_state = st.get_training_job(job_name)
    if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")


def _finalize_terminal_job(
    *,
    config_path: str,
    cfg: Config,
    status: sm_ops.TrainingJobStatus,
    command: str,
) -> None:
    """Finalize the MLflow run of a job that has reached a terminal status.

    Does nothing when the job is not terminal, when no MLflow run id is stored
    for it, or when ``mlflow_finalized_status`` is already recorded. Otherwise
    the tracker finalizes the run, the outcome is written to the state store,
    and warnings plus any registered model version are printed.
    """
    if status.status not in _TERMINAL_STATUSES:
        return

    st = state_ops.store(config_path)
    job_state = st.get_training_job(status.name)
    run_id = job_state.get("mlflow_run_id")
    # The stored finalized status doubles as the "already done" marker.
    if not run_id or job_state.get("mlflow_finalized_status"):
        return

    tracker = _tracker(cfg)
    result = tracker.finalize_training_run(
        run_id=str(run_id),
        training_job_status=status,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        command=command,
    )

    updates: dict[str, object] = {"mlflow_finalized_status": status.status}
    if result.registered_model_version:
        updates["registered_model_version"] = result.registered_model_version
    st.update_training_job(status.name, **updates)

    for warning in result.warnings:
        CONSOLE.print(f"[yellow]MLflow metrics warning: {warning}[/yellow]")

    if result.registered_model_version:
        st.set_latest_experiment_model_version(result.registered_model_version)
        CONSOLE.print(
            f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
            "([cyan]experiment-latest[/cyan])"
        )


@app.command()
def start(config: str = CONFIG_OPT) -> None:
    """Submit a SageMaker training job."""
    cfg = load_cfg(config)
    if not cfg.sagemaker.training.image_uri:
        CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
        CONSOLE.print(
            "Find pre-built images at: "
            "https://aws.github.io/deep-learning-containers/reference/available_images"
        )
        raise typer.Exit(1)

    try:
        role_arn = _sagemaker_role_arn(config, cfg)
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    # Built before submission so an MLflow setup failure exits before any
    # SageMaker job has been created.
    tracker = _tracker(cfg)

    job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
    s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"

    CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
    training_job = sm_ops.TrainingJobRequest(
        role_arn=role_arn,
        image_uri=cfg.sagemaker.training.image_uri,
        instance_type=cfg.sagemaker.training.instance_type,
        instance_count=cfg.sagemaker.training.instance_count,
        s3_train_uri=s3_train_uri,
        s3_output_path=s3_output,
        job_name=job_name,
        hyperparameters=cfg.sagemaker.training.hyperparameters,
        entry_point=cfg.sagemaker.training.entry_point,
        source_dir=cfg.sagemaker.training.source_dir,
    )
    sm_ops.start_training_job(cfg.aws.boto3_session, training_job)

    st = state_ops.store(config)
    st.set_last_training_job(job_name)

    run_id = tracker.start_training_run(
        training_job,
        region=cfg.aws.region,
        profile=cfg.aws.profile,
        role_arn=role_arn,
    )
    if run_id:
        st.update_training_job(job_name, mlflow_run_id=run_id)

    CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
    if run_id:
        CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
        CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
    CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")


@app.command()
def status(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    config: str = CONFIG_OPT,
) -> None:
    """Show training job status."""
    cfg = load_cfg(config)
    st = state_ops.store(config)
    job_name = _resolve_job_name(st, job_name)

    # Named `training_status` (as in `wait`) so the local does not rebind this
    # command function's own name.
    training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
    _print_training_status(training_status)
    _finalize_terminal_job(config_path=config, cfg=cfg, status=training_status, command="train status")
    _print_mlflow_hint(st, cfg, job_name)


@app.command()
def wait(
    job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
    poll_interval: int = typer.Option(
        DEFAULT_POLL_INTERVAL_SECONDS,
        "--poll-interval",
        min=1,
        help="Seconds between SageMaker status checks",
    ),
    config: str = CONFIG_OPT,
) -> None:
    """Wait for a training job and finalize its MLflow run."""
    cfg = load_cfg(config)
    st = state_ops.store(config)
    job_name = _resolve_job_name(st, job_name)

    previous_status: str | None = None
    try:
        while True:
            training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)

            # Only announce transitions, not every poll.
            if training_status.status != previous_status:
                color = _STATUS_COLOR.get(training_status.status, "white")
                CONSOLE.print(
                    f"Job [cyan]{training_status.name}[/cyan]: "
                    f"[{color}]{training_status.status}[/{color}]"
                )
                previous_status = training_status.status

            if training_status.status in _TERMINAL_STATUSES:
                _print_training_status(training_status)
                _finalize_terminal_job(
                    config_path=config,
                    cfg=cfg,
                    status=training_status,
                    command="train wait",
                )
                _print_mlflow_hint(st, cfg, job_name)
                return

            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # Ctrl-C only stops the local wait loop; no stop request is sent to SageMaker here.
        CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
        # 130 is the conventional exit code for termination by SIGINT.
        raise typer.Exit(130) from None


@app.command(name="list")
def list_jobs(
    limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
    config: str = CONFIG_OPT,
) -> None:
    """List recent training jobs."""
    cfg = load_cfg(config)
    jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
    if not jobs:
        CONSOLE.print("[yellow]No training jobs found.[/yellow]")
        return

    table = Table(title="Training Jobs")
    table.add_column("Name", style="cyan")
    table.add_column("Status")
    table.add_column("Created")
    for job in jobs:
        status_value = str(job["TrainingJobStatus"])
        color = _STATUS_COLOR.get(status_value, "white")
        table.add_row(
            str(job["TrainingJobName"]),
            f"[{color}]{status_value}[/{color}]",
            str(job.get("CreationTime", "")),
        )
    CONSOLE.print(table)