command to start sagemaker training

include sample training
This commit is contained in:
2026-05-25 16:48:31 -04:00
parent 62ffe163e8
commit 0e728cc193
13 changed files with 796 additions and 5 deletions

126
src/commands/train.py Normal file
View File

@@ -0,0 +1,126 @@
from datetime import datetime
import typer
from rich.table import Table
from src import state as state_ops
from src.aws import iam
from src.aws import sagemaker as sm_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
app = typer.Typer(help="Manage SageMaker training jobs")
_STATUS_COLOR = {
"Completed": "green",
"Failed": "red",
"InProgress": "yellow",
"Stopping": "yellow",
"Stopped": "dim",
}
def _config_dir(config_path: str) -> str:
from pathlib import Path
return str(Path(config_path).parent)
@app.command()
def start(config: str = CONFIG_OPT) -> None:
"""Submit a SageMaker training job."""
cfg = load_cfg(config)
if not cfg.sagemaker.training.image_uri:
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
CONSOLE.print(
"Find pre-built images at: "
"https://aws.github.io/deep-learning-containers/reference/available_images"
)
raise typer.Exit(1)
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
if not role_arn:
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
raise typer.Exit(1)
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
training_job = sm_ops.TrainingJobRequest(
role_arn=role_arn,
image_uri=cfg.sagemaker.training.image_uri,
instance_type=cfg.sagemaker.training.instance_type,
instance_count=cfg.sagemaker.training.instance_count,
s3_train_uri=s3_train_uri,
s3_output_path=s3_output,
job_name=job_name,
hyperparameters=cfg.sagemaker.training.hyperparameters,
entry_point=cfg.sagemaker.training.entry_point,
source_dir=cfg.sagemaker.training.source_dir,
)
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
state_ops.write_state(_config_dir(config), last_training_job=job_name)
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@app.command()
def status(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
config: str = CONFIG_OPT,
) -> None:
"""Show training job status."""
cfg = load_cfg(config)
if not job_name:
job_name = state_ops.get_last_training_job(_config_dir(config))
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
color = _STATUS_COLOR.get(status.status, "white")
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
config: str = CONFIG_OPT,
) -> None:
"""List recent training jobs."""
cfg = load_cfg(config)
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
if not jobs:
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
return
table = Table(title="Training Jobs")
table.add_column("Name", style="cyan")
table.add_column("Status")
table.add_column("Created")
for job in jobs:
status_value = str(job["TrainingJobStatus"])
color = _STATUS_COLOR.get(status_value, "white")
table.add_row(
str(job["TrainingJobName"]),
f"[{color}]{status_value}[/{color}]",
str(job.get("CreationTime", "")),
)
CONSOLE.print(table)