Merge branch 'main' into ml-flow
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.table import Table
|
||||
@@ -7,6 +8,8 @@ from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config
|
||||
from src.infra.state import read_infra_state
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
@@ -28,6 +31,23 @@ def _tracker(cfg):
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
|
||||
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
||||
state = read_infra_state(_config_dir(config_path))
|
||||
role_arn = state.get("outputs", {}).get("SageMakerRoleArn")
|
||||
if role_arn:
|
||||
return str(role_arn)
|
||||
if cfg.sagemaker.role_name:
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if role_arn:
|
||||
return role_arn
|
||||
raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")
|
||||
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
@@ -41,9 +61,10 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if not role_arn:
|
||||
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||
try:
|
||||
role_arn = _sagemaker_role_arn(config, cfg)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
tracker = _tracker(cfg)
|
||||
|
||||
Reference in New Issue
Block a user