create AWS infra
This commit is contained in:
193
src/commands/infra.py
Normal file
193
src/commands/infra.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.table import Table
|
||||
|
||||
from src.aws import cloudformation, identity, mlflow
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra import provisioning
|
||||
from src.infra.state import read_infra_state
|
||||
|
||||
app = typer.Typer(help="Manage AWS infrastructure")
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(
|
||||
config: str = CONFIG_OPT,
|
||||
bootstrap: bool = typer.Option(
|
||||
True,
|
||||
"--bootstrap/--no-bootstrap",
|
||||
help="Run CDK bootstrap before deploying the application stack",
|
||||
),
|
||||
) -> None:
|
||||
"""Create infrastructure with AWS CDK."""
|
||||
cfg = load_cfg(config)
|
||||
CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")
|
||||
|
||||
try:
|
||||
account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
if cfg.mlflow.mode is MlflowMode.existing:
|
||||
assert cfg.mlflow.tracking_server_name is not None
|
||||
with CONSOLE.status("Checking MLflow tracking server..."):
|
||||
server = mlflow.describe_tracking_server(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name,
|
||||
)
|
||||
if server is None:
|
||||
raise RuntimeError(f"MLflow tracking server not found: {cfg.mlflow.tracking_server_name}")
|
||||
|
||||
if bootstrap:
|
||||
with CONSOLE.status("Running cdk bootstrap..."):
|
||||
provisioning.bootstrap(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
)
|
||||
with CONSOLE.status("Running cdk deploy..."):
|
||||
state = provisioning.deploy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
config_path=config,
|
||||
config_dir=str(Path(config).parent),
|
||||
config_snapshot=cfg.model_dump(mode="json"),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
outputs = state.get("outputs", {})
|
||||
if outputs.get("DataBucketArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
|
||||
if outputs.get("SageMakerRoleArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
||||
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
||||
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def status(config: str = CONFIG_OPT) -> None:
|
||||
"""Show current infrastructure status."""
|
||||
cfg = load_cfg(config)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
|
||||
|
||||
table = Table(title="Infrastructure Status")
|
||||
table.add_column("Resource", style="cyan")
|
||||
table.add_column("Name")
|
||||
table.add_column("Status")
|
||||
table.add_column("ARN / URI")
|
||||
|
||||
if not stack:
|
||||
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
|
||||
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
||||
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
"[red]unknown[/red]",
|
||||
"-",
|
||||
)
|
||||
CONSOLE.print(table)
|
||||
return
|
||||
|
||||
outputs = stack["outputs"]
|
||||
table.add_row("CDK Stack", stack["name"], f"[green]{stack['status']}[/green]", "-")
|
||||
table.add_row(
|
||||
"S3 Bucket",
|
||||
cfg.s3.bucket,
|
||||
"[green]managed[/green]",
|
||||
outputs.get("DataBucketArn", "-"),
|
||||
)
|
||||
table.add_row(
|
||||
"IAM Role",
|
||||
cfg.sagemaker.role_name,
|
||||
"[green]managed[/green]",
|
||||
outputs.get("SageMakerRoleArn", "-"),
|
||||
)
|
||||
if cfg.mlflow.mode is MlflowMode.create:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
"[green]managed[/green]",
|
||||
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
||||
)
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
server = mlflow.describe_tracking_server(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name or "",
|
||||
)
|
||||
if server is None:
|
||||
table.add_row("MLflow", cfg.mlflow.tracking_server_name or "-", "[red]missing[/red]", "-")
|
||||
else:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
f"[green]{server.get('TrackingServerStatus', 'existing')}[/green]",
|
||||
server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-",
|
||||
)
|
||||
|
||||
CONSOLE.print(table)
|
||||
|
||||
|
||||
@app.command()
|
||||
def destroy(
|
||||
config: str = CONFIG_OPT,
|
||||
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
|
||||
delete_bucket_data: bool = typer.Option(
|
||||
False,
|
||||
"--delete-bucket-data",
|
||||
help="Delete stack-managed S3 objects and bucket instead of retaining data",
|
||||
),
|
||||
) -> None:
|
||||
"""Destroy the CDK stack."""
|
||||
cfg = _destroy_config(config)
|
||||
|
||||
if not yes and not delete_bucket_data:
|
||||
typer.confirm(
|
||||
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
try:
|
||||
account_id = _destroy_account_id(config, cfg)
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
snapshot_path = Path(temp_dir) / "config.yaml"
|
||||
snapshot_path.write_text(yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False))
|
||||
with CONSOLE.status("Running cdk destroy..."):
|
||||
provisioning.destroy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
config_path=str(snapshot_path),
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
|
||||
|
||||
|
||||
def _destroy_config(config_path: str) -> Config:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
config_snapshot = state.get("config")
|
||||
if config_snapshot:
|
||||
return Config.model_validate(config_snapshot)
|
||||
return load_cfg(config_path)
|
||||
|
||||
|
||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
account_id = state.get("aws", {}).get("account_id")
|
||||
if account_id:
|
||||
return str(account_id)
|
||||
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
Reference in New Issue
Block a user