command to start sagemaker training
include sample training
This commit is contained in:
126
src/commands/train.py
Normal file
126
src/commands/train.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from datetime import datetime
|
||||
|
||||
import typer
|
||||
from rich.table import Table
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
_STATUS_COLOR = {
|
||||
"Completed": "green",
|
||||
"Failed": "red",
|
||||
"InProgress": "yellow",
|
||||
"Stopping": "yellow",
|
||||
"Stopped": "dim",
|
||||
}
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
from pathlib import Path
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if not cfg.sagemaker.training.image_uri:
|
||||
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||
CONSOLE.print(
|
||||
"Find pre-built images at: "
|
||||
"https://aws.github.io/deep-learning-containers/reference/available_images"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if not role_arn:
|
||||
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
|
||||
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
|
||||
|
||||
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
|
||||
training_job = sm_ops.TrainingJobRequest(
|
||||
role_arn=role_arn,
|
||||
image_uri=cfg.sagemaker.training.image_uri,
|
||||
instance_type=cfg.sagemaker.training.instance_type,
|
||||
instance_count=cfg.sagemaker.training.instance_count,
|
||||
s3_train_uri=s3_train_uri,
|
||||
s3_output_path=s3_output,
|
||||
job_name=job_name,
|
||||
hyperparameters=cfg.sagemaker.training.hyperparameters,
|
||||
entry_point=cfg.sagemaker.training.entry_point,
|
||||
source_dir=cfg.sagemaker.training.source_dir,
|
||||
)
|
||||
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
|
||||
|
||||
state_ops.write_state(_config_dir(config), last_training_job=job_name)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def status(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Show training job status."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if not job_name:
|
||||
job_name = state_ops.get_last_training_job(_config_dir(config))
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""List recent training jobs."""
|
||||
cfg = load_cfg(config)
|
||||
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
|
||||
|
||||
if not jobs:
|
||||
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Training Jobs")
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Status")
|
||||
table.add_column("Created")
|
||||
|
||||
for job in jobs:
|
||||
status_value = str(job["TrainingJobStatus"])
|
||||
color = _STATUS_COLOR.get(status_value, "white")
|
||||
table.add_row(
|
||||
str(job["TrainingJobName"]),
|
||||
f"[{color}]{status_value}[/{color}]",
|
||||
str(job.get("CreationTime", "")),
|
||||
)
|
||||
|
||||
CONSOLE.print(table)
|
||||
Reference in New Issue
Block a user