247 lines
9.3 KiB
Python
247 lines
9.3 KiB
Python
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
|
|
import typer
|
|
import yaml
|
|
from rich.table import Table
|
|
|
|
from src.aws import cloudformation, identity, mlflow
|
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
|
from src.config import Config, MlflowMode
|
|
from src.infra import provisioning
|
|
from src.infra.state import read_infra_state
|
|
|
|
# Typer sub-application that groups the infrastructure commands
# (`setup`, `status` and `destroy`, each registered via `@app.command()`).
app: typer.Typer = typer.Typer(help="Manage AWS infrastructure")
|
|
|
|
|
|
@app.command()
def setup(
    config: str = CONFIG_OPT,
    bootstrap: bool = typer.Option(
        True,
        "--bootstrap/--no-bootstrap",
        help="Run CDK bootstrap before deploying the application stack",
    ),
    cloudformation_execution_policy: str | None = typer.Option(
        None,
        "--cloudformation-execution-policy",
        help="IAM policy ARN for the CDK bootstrap CloudFormation execution role",
    ),
) -> None:
    """Create infrastructure with AWS CDK.

    Resolves the AWS account id for the configured region/profile. When the
    MLflow mode is ``existing``, it verifies that the named tracking server
    exists. Unless ``--no-bootstrap`` is given, it runs ``cdk bootstrap``. It
    then runs ``cdk deploy`` and prints the resources listed in the returned
    state's ``outputs``.

    Any ``RuntimeError`` raised while resolving the account, checking MLflow,
    bootstrapping or deploying is printed in red and turned into exit code 1.
    """
    cfg = load_cfg(config)
    CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")

    try:
        account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)

        if cfg.mlflow.mode is MlflowMode.existing:
            _ensure_existing_mlflow(cfg)

        if bootstrap:
            with CONSOLE.status("Running cdk bootstrap..."):
                provisioning.bootstrap(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    region=cfg.aws.region,
                    bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                    toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                    cloudformation_execution_policy=cloudformation_execution_policy,
                )

        with CONSOLE.status("Running cdk deploy..."):
            state = provisioning.deploy(
                profile=cfg.aws.profile,
                account_id=account_id,
                region=cfg.aws.region,
                stack_name=cfg.infra.stack_name,
                bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                config_path=config,
                config_dir=str(Path(config).parent),
                # Snapshot of the resolved config, passed along with the deploy.
                config_snapshot=cfg.model_dump(mode="json"),
            )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    _print_setup_summary(cfg, state)


def _ensure_existing_mlflow(cfg: Config) -> None:
    """Fail with ``RuntimeError`` unless the configured MLflow server exists.

    Uses an explicit check instead of ``assert`` so the validation survives
    ``python -O``. The error is reported through the caller's normal
    ``RuntimeError`` handling.
    """
    server_name = cfg.mlflow.tracking_server_name
    if server_name is None:
        raise RuntimeError(
            "MLflow mode 'existing' requires mlflow.tracking_server_name to be set"
        )

    with CONSOLE.status("Checking MLflow tracking server..."):
        server = mlflow.describe_tracking_server(
            cfg.aws.region,
            cfg.aws.profile,
            server_name,
        )

    if server is None:
        raise RuntimeError(f"MLflow tracking server not found: {server_name}")


def _print_setup_summary(cfg: Config, state) -> None:
    """Print the resources reported by ``provisioning.deploy``.

    Only ``state["outputs"]`` is read. It is presumably a mapping of stack
    output name to value (verify against ``provisioning.deploy``). A missing
    or ``None`` entry is treated as empty.
    """
    outputs = state.get("outputs") or {}

    if outputs.get("DataBucketArn"):
        CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
    if outputs.get("SageMakerRoleArn"):
        CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")

    if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
        # Prefer the name the stack reports; fall back to the configured managed name.
        mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
        CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
    elif cfg.mlflow.mode is MlflowMode.existing:
        CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")

    CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
|
|
|
|
|
@app.command()
def status(config: str = CONFIG_OPT) -> None:
    """Show current infrastructure status.

    Looks up the configured CDK stack through CloudFormation and prints a table
    with rows for the stack, the S3 bucket, the SageMaker IAM role and, unless
    MLflow is disabled, the MLflow tracking server.

    If the stack is not found, every row is marked missing or unknown. Status
    cells are coloured by severity: failed or rolled-back states are red,
    transitional or stopped states are yellow, and everything else is green.
    """
    cfg = load_cfg(config)
    stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)

    table = Table(title="Infrastructure Status")
    table.add_column("Resource", style="cyan")
    table.add_column("Name")
    table.add_column("Status")
    table.add_column("ARN / URI")

    if not stack:
        table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
        table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
        table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
        if cfg.mlflow.mode is not MlflowMode.disabled:
            table.add_row(
                "MLflow",
                cfg.effective_mlflow_tracking_server_name or "-",
                "[red]unknown[/red]",
                "-",
            )
        CONSOLE.print(table)
        return

    # Tolerate a stack record without outputs (or with outputs=None).
    outputs = stack.get("outputs") or {}

    table.add_row("CDK Stack", stack["name"], _status_markup(stack["status"]), "-")
    table.add_row(
        "S3 Bucket",
        cfg.s3.bucket,
        "[green]managed[/green]",
        outputs.get("DataBucketArn", "-"),
    )
    table.add_row(
        "IAM Role",
        _role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
        "[green]managed[/green]",
        outputs.get("SageMakerRoleArn", "-"),
    )

    if cfg.mlflow.mode is MlflowMode.create:
        table.add_row(
            "MLflow",
            outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
            "[green]managed[/green]",
            outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
        )
    elif cfg.mlflow.mode is MlflowMode.existing:
        server = mlflow.describe_tracking_server(
            cfg.aws.region,
            cfg.aws.profile,
            cfg.mlflow.tracking_server_name or "",
        )
        if server is None:
            table.add_row("MLflow", cfg.mlflow.tracking_server_name or "-", "[red]missing[/red]", "-")
        else:
            table.add_row(
                "MLflow",
                cfg.mlflow.tracking_server_name or "-",
                _status_markup(server.get("TrackingServerStatus", "existing")),
                server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-",
            )

    CONSOLE.print(table)


def _status_markup(status: str) -> str:
    """Wrap *status* in Rich colour markup chosen by how healthy it looks.

    Matching ignores case and underscores. That way one rule set covers
    CloudFormation-style strings (``UPDATE_ROLLBACK_COMPLETE``,
    ``CREATE_IN_PROGRESS``) and SageMaker MLflow-style strings
    (``CreateFailed``, ``Stopping``).
    """
    text = str(status)
    key = text.upper().replace("_", "")
    if "FAILED" in key or key in {"ROLLBACKCOMPLETE", "DELETECOMPLETE"}:
        # A failed create that rolled back, or a deleted stack: not usable.
        colour = "red"
    elif (
        "INPROGRESS" in key
        or "ROLLBACK" in key
        or key == "STOPPED"
        or key.endswith(("CREATING", "UPDATING", "DELETING", "STOPPING", "STARTING"))
    ):
        # Transitional, rolled back to a previous version, or stopped.
        colour = "yellow"
    else:
        colour = "green"
    return f"[{colour}]{text}[/{colour}]"
|
|
|
|
|
|
@app.command()
def destroy(
    config: str = CONFIG_OPT,
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
    delete_bucket_data: bool = typer.Option(
        False,
        "--delete-bucket-data",
        help="Delete stack-managed S3 objects and bucket instead of retaining data",
    ),
) -> None:
    """Destroy the CDK stack.

    The stack name, bootstrap qualifier, toolkit stack name and account id are
    taken from the infra state next to *config* when recorded there. Otherwise
    they come from the config.

    Unless ``--yes`` is passed, the user must confirm. The prompt says whether
    bucket data will be retained or deleted. ``RuntimeError`` from account
    lookup or ``cdk destroy`` is printed and turned into exit code 1. The CDK
    bootstrap (toolkit) stack is never removed.
    """
    cfg = _destroy_config(config)
    stack_name = _destroy_stack_name(config, cfg)
    bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
    toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)

    # Always confirm unless --yes. Previously --delete-bucket-data (the more
    # destructive path) skipped the prompt entirely.
    if not yes:
        if delete_bucket_data:
            prompt = f"Destroy CDK stack '{stack_name}' and permanently delete its S3 bucket data?"
        else:
            prompt = f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?"
        typer.confirm(prompt, abort=True)

    try:
        account_id = _destroy_account_id(config, cfg)
        with TemporaryDirectory() as temp_dir:
            # Hand cdk the resolved config (which prefers the snapshot stored in
            # infra state) rather than the possibly-edited file on disk.
            snapshot_path = Path(temp_dir) / "config.yaml"
            snapshot_path.write_text(
                yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False),
                encoding="utf-8",
            )
            with CONSOLE.status("Running cdk destroy..."):
                provisioning.destroy(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    stack_name=stack_name,
                    bootstrap_qualifier=bootstrap_qualifier,
                    toolkit_stack_name=toolkit_stack_name,
                    config_path=str(snapshot_path),
                    delete_bucket_data=delete_bucket_data,
                )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
    CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
|
|
|
|
|
|
def _destroy_config(config_path: str) -> Config:
    """Return the configuration to use when tearing the stack down.

    Prefers the snapshot stored under ``"config"`` in the infra state found in
    *config_path*'s directory. Falls back to loading *config_path* itself when
    no non-empty snapshot is recorded.
    """
    snapshot = read_infra_state(str(Path(config_path).parent)).get("config")
    if not snapshot:
        return load_cfg(config_path)
    return Config.model_validate(snapshot)
|
|
|
|
|
|
def _role_name(configured_name: str, role_arn: str) -> str:
|
|
if configured_name:
|
|
return configured_name
|
|
if role_arn:
|
|
return role_arn.rsplit("/", 1)[-1]
|
|
return "-"
|
|
|
|
def _destroy_account_id(config_path: str, cfg: Config) -> str:
    """Return the AWS account id to destroy in.

    Uses ``state["aws"]["account_id"]`` from the infra state beside
    *config_path* when it is set. Otherwise it asks ``identity.account_id``
    for the configured region/profile.
    """
    infra_state = read_infra_state(str(Path(config_path).parent))
    recorded_account = infra_state.get("aws", {}).get("account_id")
    return str(recorded_account) if recorded_account else identity.account_id(cfg.aws.region, cfg.aws.profile)
|
|
|
|
|
|
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
    """Return the CDK stack name to destroy.

    Prefers ``stack_name`` recorded in the infra state beside *config_path*.
    Otherwise it uses ``cfg.infra.stack_name``.
    """
    state_dir = Path(config_path).parent
    recorded_name = read_infra_state(str(state_dir)).get("stack_name")
    return str(recorded_name) if recorded_name else cfg.infra.stack_name
|
|
|
|
|
|
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
    """Return the CDK bootstrap qualifier to use for destroy.

    Prefers ``bootstrap_qualifier`` recorded in the infra state beside
    *config_path*. Otherwise it uses ``cfg.infra.effective_bootstrap_qualifier``.
    """
    recorded_qualifier = read_infra_state(
        str(Path(config_path).parent)
    ).get("bootstrap_qualifier")
    if not recorded_qualifier:
        return cfg.infra.effective_bootstrap_qualifier
    return str(recorded_qualifier)
|
|
|
|
|
|
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
    """Return the CDK toolkit (bootstrap) stack name to use for destroy.

    Prefers ``toolkit_stack_name`` recorded in the infra state beside
    *config_path*. Otherwise it uses ``cfg.infra.effective_toolkit_stack_name``.
    """
    directory = str(Path(config_path).parent)
    recorded_toolkit = read_infra_state(directory).get("toolkit_stack_name")
    if not recorded_toolkit:
        return cfg.infra.effective_toolkit_stack_name
    return str(recorded_toolkit)
|