Files
qai-cli/src/commands/infra.py

247 lines
9.3 KiB
Python

from pathlib import Path
from tempfile import TemporaryDirectory
import typer
import yaml
from rich.table import Table
from src.aws import cloudformation, identity, mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra import provisioning
from src.infra.state import read_infra_state
# Typer sub-application grouping the infrastructure commands registered
# below via `@app.command()` (setup, status, destroy).
app = typer.Typer(help="Manage AWS infrastructure")
# `setup` command: resolve the AWS account ID, verify the configured MLflow
# tracking server when MLflow mode is `existing`, optionally run
# `cdk bootstrap`, then `cdk deploy` the application stack and print a
# summary of its outputs. Any RuntimeError is printed in red and the command
# exits with status 1.
# NOTE: Typer uses the docstring as the command's CLI help text, so extended
# documentation lives in these comments instead.
@app.command()
def setup(
    config: str = CONFIG_OPT,
    bootstrap: bool = typer.Option(
        True,
        "--bootstrap/--no-bootstrap",
        help="Run CDK bootstrap before deploying the application stack",
    ),
    cloudformation_execution_policy: str | None = typer.Option(
        None,
        "--cloudformation-execution-policy",
        help="IAM policy ARN for the CDK bootstrap CloudFormation execution role",
    ),
) -> None:
    """Create infrastructure with AWS CDK."""
    cfg = load_cfg(config)
    CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")
    try:
        account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)
        if cfg.mlflow.mode is MlflowMode.existing:
            server_name = cfg.mlflow.tracking_server_name
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, and an AssertionError would bypass the red error
            # message / exit code handling below.
            if not server_name:
                raise RuntimeError(
                    "MLflow mode is 'existing' but mlflow.tracking_server_name is not set"
                )
            with CONSOLE.status("Checking MLflow tracking server..."):
                server = mlflow.describe_tracking_server(
                    cfg.aws.region,
                    cfg.aws.profile,
                    server_name,
                )
            if server is None:
                raise RuntimeError(f"MLflow tracking server not found: {server_name}")
        if bootstrap:
            with CONSOLE.status("Running cdk bootstrap..."):
                provisioning.bootstrap(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    region=cfg.aws.region,
                    bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                    toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                    cloudformation_execution_policy=cloudformation_execution_policy,
                )
        with CONSOLE.status("Running cdk deploy..."):
            # The resolved config is passed as a snapshot alongside the path
            # and its directory.
            state = provisioning.deploy(
                profile=cfg.aws.profile,
                account_id=account_id,
                region=cfg.aws.region,
                stack_name=cfg.infra.stack_name,
                bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
                toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
                config_path=config,
                config_dir=str(Path(config).parent),
                config_snapshot=cfg.model_dump(mode="json"),
            )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1)
    # Treat a missing *or* null `outputs` entry as "no outputs".
    outputs = state.get("outputs") or {}
    if outputs.get("DataBucketArn"):
        CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
    if outputs.get("SageMakerRoleArn"):
        CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
    if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
        mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
        CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
    elif cfg.mlflow.mode is MlflowMode.existing:
        CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
    CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
# `status` command: look up the CDK stack and render a table covering the
# stack, the S3 bucket, the IAM role and (unless MLflow is disabled) the
# MLflow tracking server.
# NOTE: Typer uses the docstring as the command's CLI help text, so extended
# documentation lives in these comments instead.
@app.command()
def status(config: str = CONFIG_OPT) -> None:
    """Show current infrastructure status."""
    cfg = load_cfg(config)
    stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)
    table = Table(title="Infrastructure Status")
    table.add_column("Resource", style="cyan")
    table.add_column("Name")
    table.add_column("Status")
    table.add_column("ARN / URI")
    if not stack:
        # No stack: none of the stack-managed resources can be confirmed.
        table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
        table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
        # Use the same "-" placeholder as the deployed-stack branch when no
        # role name is configured.
        table.add_row(
            "IAM Role",
            _role_name(cfg.sagemaker.role_name, ""),
            "[red]unknown[/red]",
            "-",
        )
        if cfg.mlflow.mode is not MlflowMode.disabled:
            table.add_row(
                "MLflow",
                cfg.effective_mlflow_tracking_server_name or "-",
                "[red]unknown[/red]",
                "-",
            )
        CONSOLE.print(table)
        return
    outputs = stack["outputs"]
    # Colour the raw CloudFormation status so failed / rolled-back / deleting
    # stacks are not rendered in green.
    stack_state = str(stack["status"])
    stack_color = _status_color(stack_state)
    table.add_row(
        "CDK Stack",
        stack["name"],
        f"[{stack_color}]{stack_state}[/{stack_color}]",
        "-",
    )
    table.add_row(
        "S3 Bucket",
        cfg.s3.bucket,
        "[green]managed[/green]",
        outputs.get("DataBucketArn", "-"),
    )
    table.add_row(
        "IAM Role",
        _role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
        "[green]managed[/green]",
        outputs.get("SageMakerRoleArn", "-"),
    )
    if cfg.mlflow.mode is MlflowMode.create:
        table.add_row(
            "MLflow",
            outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
            "[green]managed[/green]",
            outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
        )
    elif cfg.mlflow.mode is MlflowMode.existing:
        server_name = cfg.mlflow.tracking_server_name
        # Only query AWS when a name is configured; an empty name cannot
        # identify a tracking server, so report it as missing directly.
        server = (
            mlflow.describe_tracking_server(cfg.aws.region, cfg.aws.profile, server_name)
            if server_name
            else None
        )
        if server is None:
            table.add_row("MLflow", server_name or "-", "[red]missing[/red]", "-")
        else:
            server_state = str(server.get("TrackingServerStatus", "existing"))
            server_color = _status_color(server_state)
            table.add_row(
                "MLflow",
                server_name or "-",
                f"[{server_color}]{server_state}[/{server_color}]",
                server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-",
            )
    CONSOLE.print(table)


def _status_color(status: str) -> str:
    """Return the rich colour used to render a stack / tracking-server status.

    Matching is case-insensitive and ignores underscores, so it covers both
    CloudFormation (``UPDATE_ROLLBACK_COMPLETE``) and CamelCase
    (``CreateFailed``) status styles:

    * red: failed states, ``ROLLBACK_COMPLETE`` and ``DELETE*`` states;
    * yellow: in-progress / transitional states (``*_IN_PROGRESS``,
      ``...ing``) and other rollbacks;
    * green: everything else.
    """
    normalized = status.upper().replace("_", "")
    if (
        "FAILED" in normalized
        or normalized == "ROLLBACKCOMPLETE"
        or normalized.startswith("DELETE")
    ):
        return "red"
    if "INPROGRESS" in normalized or "ROLLBACK" in normalized or normalized.endswith("ING"):
        return "yellow"
    return "green"
# `destroy` command: tear down the CDK application stack. It uses the stack
# name, bootstrap qualifier, toolkit stack name, account ID and config
# recorded in the infra state file when present, falling back to the current
# config. With --delete-bucket-data the stack-managed S3 data is deleted too.
# The CDK bootstrap (toolkit) stack is always left in place.
# NOTE: Typer uses the docstring as the command's CLI help text, so extended
# documentation lives in these comments instead.
@app.command()
def destroy(
    config: str = CONFIG_OPT,
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
    delete_bucket_data: bool = typer.Option(
        False,
        "--delete-bucket-data",
        help="Delete stack-managed S3 objects and bucket instead of retaining data",
    ),
) -> None:
    """Destroy the CDK stack."""
    cfg = _destroy_config(config)
    stack_name = _destroy_stack_name(config, cfg)
    bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
    toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)
    # Always confirm unless --yes is given. --delete-bucket-data is the
    # irreversible path, so it must never bypass the prompt; the wording
    # states which of the two outcomes the user is agreeing to.
    if not yes:
        if delete_bucket_data:
            prompt = (
                f"Destroy CDK stack '{stack_name}' and permanently DELETE its "
                "S3 bucket and all stored objects?"
            )
        else:
            prompt = f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?"
        typer.confirm(prompt, abort=True)
    try:
        account_id = _destroy_account_id(config, cfg)
        with TemporaryDirectory() as temp_dir:
            # provisioning.destroy takes a config *path*; write the resolved
            # config (possibly the recorded snapshot) to a temp file for it.
            snapshot_path = Path(temp_dir) / "config.yaml"
            snapshot_path.write_text(
                yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False),
                encoding="utf-8",
            )
            with CONSOLE.status("Running cdk destroy..."):
                # NOTE(review): unlike bootstrap/deploy no `region=` is passed;
                # presumably it is read from the config snapshot — confirm.
                provisioning.destroy(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    stack_name=stack_name,
                    bootstrap_qualifier=bootstrap_qualifier,
                    toolkit_stack_name=toolkit_stack_name,
                    config_path=str(snapshot_path),
                    delete_bucket_data=delete_bucket_data,
                )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1)
    CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
    CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
def _destroy_config(config_path: str) -> Config:
    """Return the configuration that ``destroy`` should operate on.

    If the infra state file in the config's directory holds a non-empty
    ``config`` entry, that snapshot is validated into a ``Config``.
    Otherwise the config file at *config_path* is loaded as usual.
    """
    snapshot = read_infra_state(str(Path(config_path).parent)).get("config")
    if not snapshot:
        return load_cfg(config_path)
    return Config.model_validate(snapshot)
def _role_name(configured_name: str, role_arn: str) -> str:
if configured_name:
return configured_name
if role_arn:
return role_arn.rsplit("/", 1)[-1]
return "-"
def _destroy_account_id(config_path: str, cfg: Config) -> str:
    """Return the AWS account ID to pass to ``cdk destroy``.

    Uses ``aws.account_id`` from the infra state file in the config's
    directory when it is recorded. Otherwise resolves the ID through
    ``identity.account_id`` for the configured region and profile.
    """
    state = read_infra_state(str(Path(config_path).parent))
    # `aws` may be present but null / not a mapping; chaining `.get` on it
    # would raise AttributeError, so fall back to the live lookup instead.
    aws_state = state.get("aws")
    account_id = aws_state.get("account_id") if isinstance(aws_state, dict) else None
    if account_id:
        return str(account_id)
    return identity.account_id(cfg.aws.region, cfg.aws.profile)
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
    """Return the CDK application stack name to destroy.

    A non-empty ``stack_name`` recorded in the infra state file wins;
    otherwise ``cfg.infra.stack_name`` is used.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("stack_name")
    return str(recorded) if recorded else cfg.infra.stack_name
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
    """Return the CDK bootstrap qualifier to use for ``cdk destroy``.

    A non-empty ``bootstrap_qualifier`` recorded in the infra state file wins;
    otherwise ``cfg.infra.effective_bootstrap_qualifier`` is used.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("bootstrap_qualifier")
    return str(recorded) if recorded else cfg.infra.effective_bootstrap_qualifier
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
    """Return the CDK toolkit (bootstrap) stack name to use for ``cdk destroy``.

    A non-empty ``toolkit_stack_name`` recorded in the infra state file wins;
    otherwise ``cfg.infra.effective_toolkit_stack_name`` is used.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("toolkit_stack_name")
    return str(recorded) if recorded else cfg.infra.effective_toolkit_stack_name