"""Typer commands that create, inspect, and destroy the AWS CDK infrastructure stack."""

from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import typer
import yaml
from rich.table import Table

from src.aws import cloudformation, identity, mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra import provisioning
from src.infra.state import read_infra_state

app = typer.Typer(help="Manage AWS infrastructure")


# NOTE: Typer shows a command's docstring as its CLI help text, so the command
# docstrings below are kept short; implementation notes live in `#` comments.
@app.command()
def setup(
    config: str = CONFIG_OPT,
    bootstrap: bool = typer.Option(
        True,
        "--bootstrap/--no-bootstrap",
        help="Run CDK bootstrap before deploying the application stack",
    ),
    cloudformation_execution_policy: str | None = typer.Option(
        None,
        "--cloudformation-execution-policy",
        help="IAM policy ARN for the CDK bootstrap CloudFormation execution role",
    ),
) -> None:
    """Create infrastructure with AWS CDK."""
    # Flow: load config -> resolve account id -> (existing mode) verify the
    # MLflow tracking server -> optional bootstrap -> deploy -> print outputs.
    # Any RuntimeError raised along the way is printed in red and the command
    # exits with status 1.
    cfg = load_cfg(config)

    CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")
    try:
        account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)

        if cfg.mlflow.mode is MlflowMode.existing:
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, and a RuntimeError is reported by the handler below
            # rather than surfacing as a raw AssertionError traceback.
            if cfg.mlflow.tracking_server_name is None:
                raise RuntimeError(
                    "mlflow.tracking_server_name must be set when using an existing MLflow tracking server"
                )
            with CONSOLE.status("Checking MLflow tracking server..."):
                server = mlflow.describe_tracking_server(
                    cfg.aws.region,
                    cfg.aws.profile,
                    cfg.mlflow.tracking_server_name,
                )
            if server is None:
                raise RuntimeError(f"MLflow tracking server not found: {cfg.mlflow.tracking_server_name}")

        if bootstrap:
            with CONSOLE.status("Running cdk bootstrap..."):
                provisioning.bootstrap(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    region=cfg.aws.region,
                    bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                    toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                    cloudformation_execution_policy=cloudformation_execution_policy,
                )

        with CONSOLE.status("Running cdk deploy..."):
            state = provisioning.deploy(
                profile=cfg.aws.profile,
                account_id=account_id,
                region=cfg.aws.region,
                stack_name=cfg.infra.stack_name,
                bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                config_path=config,
                config_dir=str(Path(config).parent),
                config_snapshot=cfg.model_dump(mode="json"),
            )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    # Summarise whichever stack outputs the deploy reported.
    outputs = state.get("outputs", {})
    if outputs.get("DataBucketArn"):
        CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
    if outputs.get("SageMakerRoleArn"):
        CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
    if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
        mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
        CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
    elif cfg.mlflow.mode is MlflowMode.existing:
        CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")

    CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")


@app.command()
def status(config: str = CONFIG_OPT) -> None:
    """Show current infrastructure status."""
    # Renders one table row per resource. When the stack lookup returns nothing,
    # the stack is shown as "missing" and the other resources as "unknown".
    cfg = load_cfg(config)
    stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)

    table = Table(title="Infrastructure Status")
    table.add_column("Resource", style="cyan")
    table.add_column("Name")
    table.add_column("Status")
    table.add_column("ARN / URI")

    if not stack:
        table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
        table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
        table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
        if cfg.mlflow.mode is not MlflowMode.disabled:
            table.add_row(
                "MLflow",
                cfg.effective_mlflow_tracking_server_name or "-",
                "[red]unknown[/red]",
                "-",
            )
        CONSOLE.print(table)
        return

    outputs = stack["outputs"]
    table.add_row("CDK Stack", stack["name"], f"[green]{stack['status']}[/green]", "-")
    table.add_row(
        "S3 Bucket",
        cfg.s3.bucket,
        "[green]managed[/green]",
        outputs.get("DataBucketArn", "-"),
    )
    table.add_row(
        "IAM Role",
        _role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
        "[green]managed[/green]",
        outputs.get("SageMakerRoleArn", "-"),
    )

    if cfg.mlflow.mode is MlflowMode.create:
        table.add_row(
            "MLflow",
            outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
            "[green]managed[/green]",
            outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
        )
    elif cfg.mlflow.mode is MlflowMode.existing:
        # An existing server is not part of the stack outputs, so it is looked
        # up directly by name.
        server = mlflow.describe_tracking_server(
            cfg.aws.region,
            cfg.aws.profile,
            cfg.mlflow.tracking_server_name or "",
        )
        if server is None:
            table.add_row("MLflow", cfg.mlflow.tracking_server_name or "-", "[red]missing[/red]", "-")
        else:
            table.add_row(
                "MLflow",
                cfg.mlflow.tracking_server_name or "-",
                f"[green]{server.get('TrackingServerStatus', 'existing')}[/green]",
                server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-",
            )

    CONSOLE.print(table)


@app.command()
def destroy(
    config: str = CONFIG_OPT,
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
    delete_bucket_data: bool = typer.Option(
        False,
        "--delete-bucket-data",
        help="Delete stack-managed S3 objects and bucket instead of retaining data",
    ),
) -> None:
    """Destroy the CDK stack."""
    # Values recorded in the infra state take precedence over the config file;
    # presumably so that destroy targets what was actually deployed even if the
    # config has been edited since - confirm against src.infra.state.
    # The state is read once here and shared by all the resolver helpers.
    state = _load_state(config)
    cfg = _destroy_config(config, state=state)
    stack_name = _destroy_stack_name(config, cfg, state=state)
    bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg, state=state)
    toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg, state=state)

    # NOTE(review): no prompt is shown when --delete-bucket-data is passed
    # without --yes, although that is the more destructive path. Behaviour is
    # left unchanged here; confirm whether this is intentional.
    if not yes and not delete_bucket_data:
        typer.confirm(
            f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?",
            abort=True,
        )

    try:
        account_id = _destroy_account_id(config, cfg, state=state)
        # The resolved config is written to a throwaway YAML file and that path
        # is what provisioning.destroy receives as its config.
        with TemporaryDirectory() as temp_dir:
            snapshot_path = Path(temp_dir) / "config.yaml"
            snapshot_path.write_text(
                yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False),
                encoding="utf-8",
            )
            with CONSOLE.status("Running cdk destroy..."):
                provisioning.destroy(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    stack_name=stack_name,
                    bootstrap_qualifier=bootstrap_qualifier,
                    toolkit_stack_name=toolkit_stack_name,
                    config_path=str(snapshot_path),
                    delete_bucket_data=delete_bucket_data,
                )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
    CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")


def _load_state(config_path: str, state: dict[str, Any] | None = None) -> dict[str, Any]:
    """Return *state* if given, else read the infra state for *config_path*.

    The state is looked up by the directory containing the config file.
    """
    if state is not None:
        return state
    return read_infra_state(str(Path(config_path).parent))


def _destroy_config(config_path: str, *, state: dict[str, Any] | None = None) -> Config:
    """Return the config to destroy with.

    Uses the ``"config"`` snapshot from the infra state when present, otherwise
    loads the config file at *config_path*. Pass *state* to reuse an
    already-loaded infra state instead of reading it again.
    """
    config_snapshot = _load_state(config_path, state).get("config")
    if config_snapshot:
        return Config.model_validate(config_snapshot)
    return load_cfg(config_path)


def _role_name(configured_name: str, role_arn: str) -> str:
    """Return the role name to display.

    Prefers *configured_name*, then the last ``/``-separated segment of
    *role_arn*, then ``"-"`` when both are empty.
    """
    if configured_name:
        return configured_name
    if role_arn:
        return role_arn.rsplit("/", 1)[-1]
    return "-"


def _destroy_account_id(config_path: str, cfg: Config, *, state: dict[str, Any] | None = None) -> str:
    """Return the AWS account id from the infra state, else look it up via ``identity.account_id``."""
    account_id = _load_state(config_path, state).get("aws", {}).get("account_id")
    if account_id:
        return str(account_id)
    return identity.account_id(cfg.aws.region, cfg.aws.profile)


def _destroy_stack_name(config_path: str, cfg: Config, *, state: dict[str, Any] | None = None) -> str:
    """Return the stack name from the infra state, else the configured stack name."""
    stack_name = _load_state(config_path, state).get("stack_name")
    if stack_name:
        return str(stack_name)
    return cfg.infra.stack_name


def _destroy_bootstrap_qualifier(config_path: str, cfg: Config, *, state: dict[str, Any] | None = None) -> str:
    """Return the bootstrap qualifier from the infra state, else the config's effective qualifier."""
    bootstrap_qualifier = _load_state(config_path, state).get("bootstrap_qualifier")
    if bootstrap_qualifier:
        return str(bootstrap_qualifier)
    return cfg.infra.effective_bootstrap_qualifier


def _destroy_toolkit_stack_name(config_path: str, cfg: Config, *, state: dict[str, Any] | None = None) -> str:
    """Return the toolkit stack name from the infra state, else the config's effective name."""
    toolkit_stack_name = _load_state(config_path, state).get("toolkit_stack_name")
    if toolkit_stack_name:
        return str(toolkit_stack_name)
    return cfg.infra.effective_toolkit_stack_name