create AWS infra

This commit is contained in:
2026-05-15 10:26:43 -04:00
parent 6563b4cc4b
commit 1bc5052d22
21 changed files with 1502 additions and 0 deletions

0
src/__init__.py Normal file
View File

0
src/aws/__init__.py Normal file
View File

22
src/aws/cloudformation.py Normal file
View File

@@ -0,0 +1,22 @@
from typing import Any
import boto3
from botocore.exceptions import ClientError
from src.infra.provisioning import STACK_NAME
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
    """Describe the application's CloudFormation stack.

    Args:
        region: AWS region used for the boto3 session.
        profile: Named AWS profile used for the boto3 session.

    Returns:
        A dict with the stack's ``name``, ``status`` and ``outputs``
        (output key -> output value), or ``None`` when CloudFormation
        reports that the stack does not exist.

    Raises:
        ClientError: For any CloudFormation error other than "does not exist".
    """
    session = boto3.Session(profile_name=profile, region_name=region)
    cfn = session.client("cloudformation")
    try:
        response = cfn.describe_stacks(StackName=STACK_NAME)
    except ClientError as err:
        error_text = err.response.get("Error", {}).get("Message", "")
        # Only the "does not exist" failure is mapped to None; everything
        # else is re-raised unchanged.
        if "does not exist" not in error_text:
            raise
        return None

    description = response["Stacks"][0]
    outputs: dict[str, Any] = {}
    for entry in description.get("Outputs", []):
        outputs[entry["OutputKey"]] = entry.get("OutputValue", "")
    return {
        "name": description["StackName"],
        "status": description["StackStatus"],
        "outputs": outputs,
    }

5
src/aws/identity.py Normal file
View File

@@ -0,0 +1,5 @@
import boto3
def account_id(region: str, profile: str) -> str:
    """Return the ``Account`` field of the STS caller identity for the given profile."""
    session = boto3.Session(profile_name=profile, region_name=region)
    caller_identity = session.client("sts").get_caller_identity()
    return caller_identity["Account"]

19
src/aws/mlflow.py Normal file
View File

@@ -0,0 +1,19 @@
from typing import Any, cast
import boto3
from botocore.exceptions import ClientError
def describe_tracking_server(region: str, profile: str, name: str) -> dict[str, Any] | None:
    """Look up a SageMaker MLflow tracking server by name.

    Args:
        region: AWS region used for the boto3 session.
        profile: Named AWS profile used for the boto3 session.
        name: Tracking server name passed as ``TrackingServerName``.

    Returns:
        The raw ``describe_mlflow_tracking_server`` response, or ``None`` when
        the error code/message indicates the server was not found.

    Raises:
        ClientError: For any other SageMaker error.
    """
    sagemaker_client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
    try:
        response = sagemaker_client.describe_mlflow_tracking_server(TrackingServerName=name)
    except ClientError as err:
        details = err.response.get("Error", {})
        code = details.get("Code", "")
        message = details.get("Message", "")
        # NOTE(review): ValidationException is treated as "not found" here,
        # which may also hide malformed-name errors - confirm this is intended.
        missing_codes = {"ResourceNotFound", "ResourceNotFoundException", "ValidationException"}
        if code not in missing_codes and "not found" not in message.lower():
            raise
        return None
    return cast(dict[str, Any], response)

0
src/commands/__init__.py Normal file
View File

193
src/commands/infra.py Normal file
View File

@@ -0,0 +1,193 @@
from pathlib import Path
from tempfile import TemporaryDirectory
import typer
import yaml
from rich.table import Table
from src.aws import cloudformation, identity, mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra import provisioning
from src.infra.state import read_infra_state
# Typer application that groups the infrastructure commands defined in this
# module (setup, status, destroy).
app = typer.Typer(help="Manage AWS infrastructure")
@app.command()
def setup(
    config: str = CONFIG_OPT,
    bootstrap: bool = typer.Option(
        True,
        "--bootstrap/--no-bootstrap",
        help="Run CDK bootstrap before deploying the application stack",
    ),
) -> None:
    """Create infrastructure with AWS CDK."""
    # Flow: resolve the AWS account, verify a pre-existing MLflow server when
    # the config points at one, optionally bootstrap CDK, deploy the stack and
    # finally print the outputs recorded by the deploy step.
    # Any RuntimeError raised along the way is printed in red and converted
    # into exit code 1.
    cfg = load_cfg(config)
    CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")
    try:
        account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)
        if cfg.mlflow.mode is MlflowMode.existing:
            server_name = cfg.mlflow.tracking_server_name
            # Explicit check instead of `assert`: assertions are removed under
            # `python -O`, and a RuntimeError is reported cleanly below.
            if server_name is None:
                raise RuntimeError(
                    "mlflow.tracking_server_name is required when mlflow.mode is create or existing"
                )
            with CONSOLE.status("Checking MLflow tracking server..."):
                server = mlflow.describe_tracking_server(
                    cfg.aws.region,
                    cfg.aws.profile,
                    server_name,
                )
            if server is None:
                raise RuntimeError(f"MLflow tracking server not found: {server_name}")
        if bootstrap:
            with CONSOLE.status("Running cdk bootstrap..."):
                provisioning.bootstrap(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    region=cfg.aws.region,
                )
        with CONSOLE.status("Running cdk deploy..."):
            state = provisioning.deploy(
                profile=cfg.aws.profile,
                account_id=account_id,
                region=cfg.aws.region,
                config_path=config,
                config_dir=str(Path(config).parent),
                config_snapshot=cfg.model_dump(mode="json"),
            )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e

    # Report only the resources for which the deploy produced an output.
    outputs = state.get("outputs", {})
    if outputs.get("DataBucketArn"):
        CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
    if outputs.get("SageMakerRoleArn"):
        CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
    if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
        CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
    elif cfg.mlflow.mode is MlflowMode.existing:
        CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
    CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
@app.command()
def status(config: str = CONFIG_OPT) -> None:
    """Show current infrastructure status."""
    # Renders one table row per resource. When the stack is missing, every
    # resource is listed as unknown; otherwise the stack outputs supply the
    # ARN / URI column.
    cfg = load_cfg(config)
    stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
    mlflow_mode = cfg.mlflow.mode
    mlflow_label = cfg.mlflow.tracking_server_name or "-"

    table = Table(title="Infrastructure Status")
    table.add_column("Resource", style="cyan")
    for heading in ("Name", "Status", "ARN / URI"):
        table.add_column(heading)

    if not stack:
        unknown = "[red]unknown[/red]"
        rows = [
            ("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-"),
            ("S3 Bucket", cfg.s3.bucket, unknown, "-"),
            ("IAM Role", cfg.sagemaker.role_name, unknown, "-"),
        ]
        if mlflow_mode is not MlflowMode.disabled:
            rows.append(("MLflow", mlflow_label, unknown, "-"))
        for row in rows:
            table.add_row(*row)
        CONSOLE.print(table)
        return

    outputs = stack["outputs"]
    managed = "[green]managed[/green]"
    table.add_row("CDK Stack", stack["name"], f"[green]{stack['status']}[/green]", "-")
    table.add_row("S3 Bucket", cfg.s3.bucket, managed, outputs.get("DataBucketArn", "-"))
    table.add_row("IAM Role", cfg.sagemaker.role_name, managed, outputs.get("SageMakerRoleArn", "-"))

    if mlflow_mode is MlflowMode.create:
        # Prefer the server ARN; fall back to the artifact URI, then "-".
        artifact_fallback = outputs.get("MlflowArtifactUri", "-")
        table.add_row(
            "MLflow",
            mlflow_label,
            managed,
            outputs.get("MlflowTrackingServerArn", artifact_fallback),
        )
    elif mlflow_mode is MlflowMode.existing:
        # An existing server is not part of the stack, so query it directly.
        server = mlflow.describe_tracking_server(
            cfg.aws.region,
            cfg.aws.profile,
            cfg.mlflow.tracking_server_name or "",
        )
        if server is None:
            table.add_row("MLflow", mlflow_label, "[red]missing[/red]", "-")
        else:
            server_state = server.get("TrackingServerStatus", "existing")
            location = server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-"
            table.add_row("MLflow", mlflow_label, f"[green]{server_state}[/green]", location)
    CONSOLE.print(table)
@app.command()
def destroy(
    config: str = CONFIG_OPT,
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
    delete_bucket_data: bool = typer.Option(
        False,
        "--delete-bucket-data",
        help="Delete stack-managed S3 objects and bucket instead of retaining data",
    ),
) -> None:
    """Destroy the CDK stack."""
    # The config and account id are taken from the saved infra state when
    # available (see _destroy_config / _destroy_account_id), falling back to
    # the config file and a live STS lookup.
    cfg = _destroy_config(config)
    if not yes:
        # Always confirm unless --yes was given. Previously the prompt was
        # skipped whenever --delete-bucket-data was passed, so the most
        # destructive variant ran without any confirmation.
        if delete_bucket_data:
            prompt = (
                f"Destroy CDK stack '{provisioning.STACK_NAME}' and permanently delete "
                "its S3 bucket data?"
            )
        else:
            prompt = f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?"
        typer.confirm(prompt, abort=True)
    try:
        account_id = _destroy_account_id(config, cfg)
        # The resolved config is written to a throwaway file and that path is
        # handed to the CDK command instead of the user's config file.
        with TemporaryDirectory() as temp_dir:
            snapshot_path = Path(temp_dir) / "config.yaml"
            snapshot_path.write_text(yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False))
            with CONSOLE.status("Running cdk destroy..."):
                provisioning.destroy(
                    profile=cfg.aws.profile,
                    account_id=account_id,
                    config_path=str(snapshot_path),
                    delete_bucket_data=delete_bucket_data,
                )
    except RuntimeError as e:
        CONSOLE.print(f"[red]{e}[/red]")
        raise typer.Exit(1) from e
    CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
def _destroy_config(config_path: str) -> Config:
    """Return the config to destroy with.

    Uses the ``config`` snapshot stored in the infra state file located next
    to ``config_path`` when one is present and non-empty; otherwise loads the
    config file itself.
    """
    state_dir = str(Path(config_path).parent)
    saved_config = read_infra_state(state_dir).get("config")
    if not saved_config:
        return load_cfg(config_path)
    return Config.model_validate(saved_config)
def _destroy_account_id(config_path: str, cfg: Config) -> str:
    """Return the AWS account id to destroy in.

    Uses ``aws.account_id`` from the infra state file located next to
    ``config_path`` when available; otherwise asks STS using the region and
    profile from ``cfg``.
    """
    state_dir = str(Path(config_path).parent)
    recorded = read_infra_state(state_dir).get("aws", {}).get("account_id")
    if not recorded:
        return identity.account_id(cfg.aws.region, cfg.aws.profile)
    return str(recorded)

29
src/commands/utils.py Normal file
View File

@@ -0,0 +1,29 @@
from pathlib import Path
import typer
import yaml
from rich.console import Console
from src.config import Config
# Module-level Rich console used for CLI output.
CONSOLE = Console()
# Reusable `--config` / `-c` option definition; defaults to "config.yaml".
CONFIG_OPT = typer.Option("config.yaml", "--config", "-c", help="Path to config file")
def load_config(path: str = "config.yaml") -> Config:
    """Load and validate the YAML config file at ``path``.

    Args:
        path: Location of the config file.

    Returns:
        The validated ``Config``. An empty (or comment-only) file yields a
        ``Config`` built from the model defaults.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Config file not found: {config_path}. Run 'qai-cli init' to create one."
        )
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # An empty YAML document parses to None; treat it as "no overrides" so the
    # model defaults apply instead of failing validation. Other non-mapping
    # values (lists, scalars) are still passed through and rejected.
    return Config.model_validate({} if data is None else data)
def load_cfg(path: str = "config.yaml") -> Config:
    """CLI wrapper around ``load_config`` that exits instead of raising.

    A missing file, malformed YAML or a config that fails validation is
    printed in red and turned into ``typer.Exit(1)``.
    """
    try:
        return load_config(path)
    # yaml.YAMLError covers parse errors; ValueError covers validation
    # failures (pydantic v2's ValidationError subclasses ValueError).
    except (FileNotFoundError, yaml.YAMLError, ValueError) as e:
        # markup=False: the message may contain square brackets (paths,
        # validation details) that must not be interpreted as Rich tags.
        CONSOLE.print(str(e), style="red", markup=False)
        raise typer.Exit(1) from e

66
src/config.py Normal file
View File

@@ -0,0 +1,66 @@
from enum import Enum
from typing import Any, Literal
from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
class MlflowMode(str, Enum):
    """String-valued selector for the ``mlflow.mode`` config field."""

    disabled = "disabled"
    create = "create"
    existing = "existing"
class MlflowServerSize(str, Enum):
    """String-valued selector for the ``mlflow.tracking_server_size`` config field."""

    small = "Small"
    medium = "Medium"
    large = "Large"
class AwsConfig(BaseModel):
    # AWS session settings: the region and the named credentials profile.

    # Any S3 location-constraint region, plus "us-east-1" which is added
    # explicitly to the accepted values.
    region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
    # Named AWS profile passed to boto3 sessions and the CDK CLI.
    profile: str = "default"
class S3Config(BaseModel):
    # S3 settings: bucket name and the key prefixes used inside it.

    # Name of the data bucket.
    bucket: str = "my-onnx-bucket"
    # Key prefix for data objects.
    data_prefix: str = "data/"
    # Key prefix for model objects.
    model_prefix: str = "models/"
class TrainingConfig(BaseModel):
    # SageMaker training job settings.

    # Instance type and number of instances for training.
    instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
    instance_count: int = 1
    # Training container image; empty by default.
    image_uri: str = ""
    # Optional entry point and source directory.
    entry_point: str | None = None
    source_dir: str | None = None
    # Free-form hyperparameters; default_factory gives each instance its own dict.
    hyperparameters: dict[str, Any] = Field(default_factory=dict)
class SageMakerConfig(BaseModel):
    # SageMaker settings: the IAM role name and nested training settings.

    # Name of the IAM role used for SageMaker.
    role_name: str = "qai-cli-sagemaker-role"
    # Nested training settings; a fresh TrainingConfig is built per instance.
    training: TrainingConfig = Field(default_factory=TrainingConfig)
class MlflowConfig(BaseModel):
    # MLflow tracking server settings. A tracking server name is mandatory
    # whenever the mode is "create" or "existing" (enforced by the validator).

    mode: MlflowMode = MlflowMode.disabled
    tracking_server_name: str | None = None
    # Key prefix for MLflow artifacts.
    artifact_prefix: str = "mlflow/"
    tracking_server_size: MlflowServerSize = MlflowServerSize.small
    mlflow_version: str | None = None
    automatic_model_registration: bool = False
    weekly_maintenance_window_start: str | None = None

    @model_validator(mode="after")
    def require_tracking_server_name(self) -> "MlflowConfig":
        # Reject configs that need a tracking server but do not name one
        # (None and the empty string are both treated as missing).
        name_required = self.mode in {MlflowMode.create, MlflowMode.existing}
        if name_required and not self.tracking_server_name:
            raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
        return self
class Config(BaseModel):
    # Root configuration model. Every section has a default factory, so an
    # empty mapping validates to a fully-defaulted Config.

    aws: AwsConfig = Field(default_factory=AwsConfig)
    s3: S3Config = Field(default_factory=S3Config)
    sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
    mlflow: MlflowConfig = Field(default_factory=MlflowConfig)

0
src/infra/__init__.py Normal file
View File

118
src/infra/provisioning.py Normal file
View File

@@ -0,0 +1,118 @@
import json
import subprocess
from pathlib import Path
from typing import Any
from src.infra.state import state_path, write_infra_state
# Name of the CDK stack; passed to the CDK CLI and used as the key when
# reading the stack's entries from the CDK outputs file.
STACK_NAME = "QaiCliStack"
def bootstrap(*, profile: str, account_id: str, region: str) -> None:
    """Run ``cdk bootstrap`` for the ``aws://<account_id>/<region>`` environment.

    Raises:
        RuntimeError: If the CDK CLI is missing or the command fails.
    """
    environment = f"aws://{account_id}/{region}"
    _run(["cdk", "bootstrap", environment, "--profile", profile])
def deploy(
    *,
    profile: str,
    account_id: str,
    region: str,
    config_path: str,
    config_dir: str,
    config_snapshot: dict[str, Any],
) -> dict[str, Any]:
    """Deploy the stack with ``cdk deploy`` and persist the resulting state.

    Args:
        profile: AWS profile passed to the CDK CLI.
        account_id: AWS account id passed to the CDK app as context.
        region: AWS region recorded in the saved state.
        config_path: Config file path passed to the CDK app as context.
        config_dir: Directory holding the infra state and CDK outputs files.
        config_snapshot: Serialized config stored in the state under ``config``.

    Returns:
        The state dict that was written: stack name, AWS identity, config
        snapshot and the stack outputs read from the CDK outputs file.

    Raises:
        RuntimeError: If the CDK CLI is missing or the command fails.
    """
    # The outputs file sits next to the state file, with the state file's
    # final suffix replaced by ".cdk-outputs.json".
    outputs_file = state_path(config_dir).with_suffix(".cdk-outputs.json")
    base_cmd = _cdk_cmd(
        "deploy",
        profile=profile,
        account_id=account_id,
        config_path=config_path,
        delete_bucket_data=False,
    )
    _run([*base_cmd, "--require-approval", "never", "--outputs-file", str(outputs_file)])

    stack_outputs = _read_outputs(outputs_file)
    aws_identity = {"account_id": account_id, "region": region, "profile": profile}
    state = {
        "stack_name": STACK_NAME,
        "aws": aws_identity,
        "config": config_snapshot,
        "outputs": stack_outputs,
    }
    write_infra_state(config_dir, state)
    return state
def destroy(
    *,
    profile: str,
    account_id: str,
    config_path: str,
    delete_bucket_data: bool,
) -> None:
    """Destroy the stack with ``cdk destroy --force``.

    When ``delete_bucket_data`` is true, the stack is first re-deployed with
    the ``delete_bucket_data=true`` context value and only then destroyed.

    Raises:
        RuntimeError: If the CDK CLI is missing or a command fails.
    """
    shared = {"profile": profile, "account_id": account_id, "config_path": config_path}
    if delete_bucket_data:
        # presumably this update switches the bucket to a deletable removal
        # policy before the destroy runs - confirm against the CDK app.
        redeploy = _cdk_cmd("deploy", delete_bucket_data=True, **shared)
        _run([*redeploy, "--require-approval", "never"])
    teardown = _cdk_cmd("destroy", delete_bucket_data=delete_bucket_data, **shared)
    _run([*teardown, "--force"])
def _cdk_cmd(
    action: str,
    *,
    profile: str,
    account_id: str,
    config_path: str,
    delete_bucket_data: bool,
) -> list[str]:
    """Build the argument list for a ``cdk <action>`` call on the stack.

    The command targets ``STACK_NAME``, uses ``python app.py`` as the CDK app
    and passes the account id, config path, stack name and the lower-cased
    ``delete_bucket_data`` flag as ``-c key=value`` context entries.
    """
    context = {
        "account_id": account_id,
        "config": config_path,
        "stack_name": STACK_NAME,
        "delete_bucket_data": str(delete_bucket_data).lower(),
    }
    cmd = ["cdk", action, STACK_NAME, "--app", "python app.py", "--profile", profile]
    # Dict insertion order keeps the context flags in a fixed order.
    for key, value in context.items():
        cmd += ["-c", f"{key}={value}"]
    return cmd
def _run(cmd: list[str]) -> None:
    """Run ``cmd`` and translate failures into ``RuntimeError``.

    Raises:
        RuntimeError: If the executable cannot be found, or if the command
            exits with a non-zero status (the exit code is in the message).
    """
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as failure:
        message = f"CDK command failed with exit code {failure.returncode}."
        raise RuntimeError(message) from failure
    except FileNotFoundError as missing:
        raise RuntimeError("CDK CLI not found. Install it with: npm install -g aws-cdk") from missing
def _read_outputs(path: Path) -> dict[str, str]:
    """Return this stack's entry from a CDK outputs JSON file.

    Returns an empty dict when the file does not exist or has no entry keyed
    by ``STACK_NAME``.
    """
    if path.exists():
        with open(path) as handle:
            all_outputs = json.load(handle)
        return all_outputs.get(STACK_NAME, {})
    return {}

169
src/infra/stack.py Normal file
View File

@@ -0,0 +1,169 @@
from typing import Any
from aws_cdk import CfnOutput, RemovalPolicy, Stack
from aws_cdk import aws_iam as iam
from aws_cdk import aws_s3 as s3
from aws_cdk import aws_sagemaker as sagemaker
from constructs import Construct
from src.config import Config, MlflowMode
class QaiStack(Stack):
    """CDK stack with the data bucket, the SageMaker role and optional MLflow resources.

    Always created: a versioned S3 bucket, an IAM role assumable by SageMaker
    with an S3 access policy scoped to that bucket, and outputs for the bucket
    name/ARN and the role ARN. When ``config.mlflow.mode`` is ``create``, an
    MLflow tracking server with its own role and policy is added as well.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        *,
        config: Config,
        delete_bucket_data: bool = False,
        **kwargs: Any,
    ) -> None:
        """Define the stack's resources.

        Args:
            scope: Parent construct.
            construct_id: Construct id of this stack.
            config: Application config supplying resource names and MLflow settings.
            delete_bucket_data: When true the bucket uses ``RemovalPolicy.DESTROY``
                with ``auto_delete_objects``; otherwise it is retained.
            **kwargs: Forwarded unchanged to ``Stack.__init__``.
        """
        super().__init__(scope, construct_id, **kwargs)
        removal_policy = RemovalPolicy.DESTROY if delete_bucket_data else RemovalPolicy.RETAIN
        data_bucket = s3.Bucket(
            self,
            "DataBucket",
            bucket_name=config.s3.bucket,
            versioned=True,
            removal_policy=removal_policy,
            auto_delete_objects=delete_bucket_data,
        )
        role = iam.CfnRole(
            self,
            "SageMakerRole",
            role_name=config.sagemaker.role_name,
            assume_role_policy_document=self._sagemaker_trust_policy(),
            managed_policy_arns=[
                f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
            ],
        )
        iam.CfnPolicy(
            self,
            "SageMakerRoleS3Policy",
            roles=[role.ref],
            policy_name="SageMakerRoleS3Policy",
            policy_document={
                "Version": "2012-10-17",
                "Statement": [{
                    "Effect": "Allow",
                    "Action": self._s3_read_write_actions(),
                    "Resource": [
                        data_bucket.bucket_arn,
                        f"{data_bucket.bucket_arn}/*",
                    ],
                }],
            },
        )
        CfnOutput(self, "DataBucketName", value=data_bucket.bucket_name)
        CfnOutput(self, "DataBucketArn", value=data_bucket.bucket_arn)
        CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
        if config.mlflow.mode is MlflowMode.create:
            self._add_mlflow_tracking_server(config, data_bucket)

    def _add_mlflow_tracking_server(self, config: Config, data_bucket: s3.Bucket) -> None:
        """Add the MLflow role, its policy, the tracking server and their outputs."""
        # Normalise the prefix so it never starts or ends with "/"; an empty
        # prefix means artifacts live at the bucket root.
        artifact_prefix = config.mlflow.artifact_prefix.strip("/")
        artifact_uri = (
            f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
            if artifact_prefix
            else f"s3://{data_bucket.bucket_name}/"
        )
        mlflow_role = iam.CfnRole(
            self,
            "MlflowRole",
            assume_role_policy_document=self._sagemaker_trust_policy(),
        )
        # Object access is limited to the artifact prefix when one is set.
        s3_statement: dict[str, Any] = {
            "Effect": "Allow",
            "Action": self._s3_read_write_actions(),
            "Resource": [
                data_bucket.bucket_arn,
                (
                    f"{data_bucket.bucket_arn}/{artifact_prefix}/*"
                    if artifact_prefix
                    else f"{data_bucket.bucket_arn}/*"
                ),
            ],
        }
        list_statement: dict[str, Any] = {
            "Effect": "Allow",
            "Action": "s3:ListBucket",
            "Resource": data_bucket.bucket_arn,
        }
        if artifact_prefix:
            list_statement["Condition"] = {"StringLike": {"s3:prefix": [f"{artifact_prefix}/*"]}}
        mlflow_policy = iam.CfnPolicy(
            self,
            "MlflowRolePolicy",
            roles=[mlflow_role.ref],
            policy_name="MlflowRolePolicy",
            policy_document={
                "Version": "2012-10-17",
                "Statement": [
                    s3_statement,
                    list_statement,
                    {
                        "Effect": "Allow",
                        "Action": [
                            "sagemaker:AddTags",
                            "sagemaker:CreateModelPackageGroup",
                            "sagemaker:CreateModelPackage",
                            "sagemaker:UpdateModelPackage",
                            "sagemaker:DescribeModelPackageGroup",
                        ],
                        "Resource": "*",
                    },
                ],
            },
        )
        tracking_server = sagemaker.CfnMlflowTrackingServer(
            self,
            "MlflowTrackingServer",
            artifact_store_uri=artifact_uri,
            role_arn=mlflow_role.attr_arn,
            tracking_server_name=config.mlflow.tracking_server_name or "",
            automatic_model_registration=config.mlflow.automatic_model_registration,
            mlflow_version=config.mlflow.mlflow_version,
            tracking_server_size=config.mlflow.tracking_server_size.value,
            weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
        )
        # The server only references the role, not the policy, so without an
        # explicit dependency CloudFormation is free to create it before the
        # role's permissions are attached.
        tracking_server.add_dependency(mlflow_policy)
        CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
        CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
        CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
        CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)

    @staticmethod
    def _s3_read_write_actions() -> list[str]:
        """Return the S3 actions granted to both the SageMaker and MLflow roles."""
        return [
            "s3:GetObject*",
            "s3:GetBucket*",
            "s3:List*",
            "s3:DeleteObject*",
            "s3:PutObject",
            "s3:PutObjectLegalHold",
            "s3:PutObjectRetention",
            "s3:PutObjectTagging",
            "s3:PutObjectVersionTagging",
            "s3:Abort*",
        ]

    @staticmethod
    def _sagemaker_trust_policy() -> dict[str, Any]:
        """Return a trust policy that lets the SageMaker service assume the role."""
        return {
            "Version": "2012-10-17",
            "Statement": [{
                "Effect": "Allow",
                "Principal": {"Service": "sagemaker.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }],
        }

22
src/infra/state.py Normal file
View File

@@ -0,0 +1,22 @@
import json
from pathlib import Path
from typing import Any
# File name of the JSON infra state file, resolved relative to the config directory.
INFRA_STATE_FILE = ".qai-cli-infra.json"
def state_path(config_dir: str) -> Path:
    """Return the path of the infra state file inside ``config_dir``."""
    directory = Path(config_dir)
    return directory / INFRA_STATE_FILE
def read_infra_state(config_dir: str) -> dict[str, Any]:
    """Load the infra state stored in ``config_dir``.

    Returns an empty dict when no state file exists.
    """
    location = state_path(config_dir)
    if location.exists():
        with location.open() as handle:
            return json.load(handle)
    return {}
def write_infra_state(config_dir: str, state: dict[str, Any]) -> None:
    """Write ``state`` as indented JSON to the state file in ``config_dir``, replacing any previous content."""
    location = state_path(config_dir)
    with location.open("w") as handle:
        json.dump(state, handle, indent=2)

39
src/main.py Normal file
View File

@@ -0,0 +1,39 @@
from pathlib import Path
import typer
import yaml
from rich.console import Console
from src.commands import infra
from src.config import Config
# Root Typer application; running it without arguments shows the help text.
app = typer.Typer(
    help="qai-cli: End-to-end model management for Qualcomm AI Hub.",
    no_args_is_help=True,
)
# Infrastructure commands are exposed under the `infra` sub-command.
app.add_typer(infra.app, name="infra")
console = Console()
@app.command()
def init(
    output: str = typer.Option("config.yaml", help="Destination path for the config file"),
    force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
) -> None:
    """Write a starter config.yaml to the current directory."""
    # Serialises a default-constructed Config to YAML at `output`, creating
    # parent directories as needed. Refuses to overwrite an existing file
    # unless --force is given (exit code 1).
    dest = Path(output)
    if not force and dest.exists():
        console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
        raise typer.Exit(1)

    defaults = Config().model_dump(mode="json")
    dest.parent.mkdir(parents=True, exist_ok=True)
    with dest.open("w") as handle:
        # sort_keys=False keeps the sections in model-declaration order.
        yaml.safe_dump(defaults, handle, sort_keys=False)

    console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
    next_steps = (
        "Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
        "before running other commands."
    )
    console.print(next_steps)