create AWS infra
This commit is contained in:
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/aws/__init__.py
Normal file
0
src/aws/__init__.py
Normal file
22
src/aws/cloudformation.py
Normal file
22
src/aws/cloudformation.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from src.infra.provisioning import STACK_NAME
|
||||
|
||||
|
||||
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
|
||||
try:
|
||||
stack = client.describe_stacks(StackName=STACK_NAME)["Stacks"][0]
|
||||
except ClientError as e:
|
||||
message = e.response.get("Error", {}).get("Message", "")
|
||||
if "does not exist" in message:
|
||||
return None
|
||||
raise
|
||||
return {
|
||||
"name": stack["StackName"],
|
||||
"status": stack["StackStatus"],
|
||||
"outputs": {item["OutputKey"]: item.get("OutputValue", "") for item in stack.get("Outputs", [])},
|
||||
}
|
||||
5
src/aws/identity.py
Normal file
5
src/aws/identity.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import boto3
|
||||
|
||||
|
||||
def account_id(region: str, profile: str) -> str:
|
||||
return boto3.Session(profile_name=profile, region_name=region).client("sts").get_caller_identity()["Account"]
|
||||
19
src/aws/mlflow.py
Normal file
19
src/aws/mlflow.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
def describe_tracking_server(region: str, profile: str, name: str) -> dict[str, Any] | None:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||
try:
|
||||
return cast(dict[str, Any], client.describe_mlflow_tracking_server(TrackingServerName=name))
|
||||
except ClientError as e:
|
||||
code = e.response.get("Error", {}).get("Code", "")
|
||||
message = e.response.get("Error", {}).get("Message", "")
|
||||
if (
|
||||
code in {"ResourceNotFound", "ResourceNotFoundException", "ValidationException"}
|
||||
or "not found" in message.lower()
|
||||
):
|
||||
return None
|
||||
raise
|
||||
0
src/commands/__init__.py
Normal file
0
src/commands/__init__.py
Normal file
193
src/commands/infra.py
Normal file
193
src/commands/infra.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.table import Table
|
||||
|
||||
from src.aws import cloudformation, identity, mlflow
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra import provisioning
|
||||
from src.infra.state import read_infra_state
|
||||
|
||||
app = typer.Typer(help="Manage AWS infrastructure")
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(
|
||||
config: str = CONFIG_OPT,
|
||||
bootstrap: bool = typer.Option(
|
||||
True,
|
||||
"--bootstrap/--no-bootstrap",
|
||||
help="Run CDK bootstrap before deploying the application stack",
|
||||
),
|
||||
) -> None:
|
||||
"""Create infrastructure with AWS CDK."""
|
||||
cfg = load_cfg(config)
|
||||
CONSOLE.print("[bold]Deploying infrastructure with AWS CDK...[/bold]")
|
||||
|
||||
try:
|
||||
account_id = identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
if cfg.mlflow.mode is MlflowMode.existing:
|
||||
assert cfg.mlflow.tracking_server_name is not None
|
||||
with CONSOLE.status("Checking MLflow tracking server..."):
|
||||
server = mlflow.describe_tracking_server(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name,
|
||||
)
|
||||
if server is None:
|
||||
raise RuntimeError(f"MLflow tracking server not found: {cfg.mlflow.tracking_server_name}")
|
||||
|
||||
if bootstrap:
|
||||
with CONSOLE.status("Running cdk bootstrap..."):
|
||||
provisioning.bootstrap(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
)
|
||||
with CONSOLE.status("Running cdk deploy..."):
|
||||
state = provisioning.deploy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
config_path=config,
|
||||
config_dir=str(Path(config).parent),
|
||||
config_snapshot=cfg.model_dump(mode="json"),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
outputs = state.get("outputs", {})
|
||||
if outputs.get("DataBucketArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] S3 bucket: {outputs['DataBucketArn']}")
|
||||
if outputs.get("SageMakerRoleArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
||||
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
||||
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def status(config: str = CONFIG_OPT) -> None:
|
||||
"""Show current infrastructure status."""
|
||||
cfg = load_cfg(config)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
|
||||
|
||||
table = Table(title="Infrastructure Status")
|
||||
table.add_column("Resource", style="cyan")
|
||||
table.add_column("Name")
|
||||
table.add_column("Status")
|
||||
table.add_column("ARN / URI")
|
||||
|
||||
if not stack:
|
||||
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
|
||||
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
||||
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
"[red]unknown[/red]",
|
||||
"-",
|
||||
)
|
||||
CONSOLE.print(table)
|
||||
return
|
||||
|
||||
outputs = stack["outputs"]
|
||||
table.add_row("CDK Stack", stack["name"], f"[green]{stack['status']}[/green]", "-")
|
||||
table.add_row(
|
||||
"S3 Bucket",
|
||||
cfg.s3.bucket,
|
||||
"[green]managed[/green]",
|
||||
outputs.get("DataBucketArn", "-"),
|
||||
)
|
||||
table.add_row(
|
||||
"IAM Role",
|
||||
cfg.sagemaker.role_name,
|
||||
"[green]managed[/green]",
|
||||
outputs.get("SageMakerRoleArn", "-"),
|
||||
)
|
||||
if cfg.mlflow.mode is MlflowMode.create:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
"[green]managed[/green]",
|
||||
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
||||
)
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
server = mlflow.describe_tracking_server(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name or "",
|
||||
)
|
||||
if server is None:
|
||||
table.add_row("MLflow", cfg.mlflow.tracking_server_name or "-", "[red]missing[/red]", "-")
|
||||
else:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
f"[green]{server.get('TrackingServerStatus', 'existing')}[/green]",
|
||||
server.get("TrackingServerArn") or server.get("ArtifactStoreUri") or "-",
|
||||
)
|
||||
|
||||
CONSOLE.print(table)
|
||||
|
||||
|
||||
@app.command()
|
||||
def destroy(
|
||||
config: str = CONFIG_OPT,
|
||||
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
|
||||
delete_bucket_data: bool = typer.Option(
|
||||
False,
|
||||
"--delete-bucket-data",
|
||||
help="Delete stack-managed S3 objects and bucket instead of retaining data",
|
||||
),
|
||||
) -> None:
|
||||
"""Destroy the CDK stack."""
|
||||
cfg = _destroy_config(config)
|
||||
|
||||
if not yes and not delete_bucket_data:
|
||||
typer.confirm(
|
||||
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
try:
|
||||
account_id = _destroy_account_id(config, cfg)
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
snapshot_path = Path(temp_dir) / "config.yaml"
|
||||
snapshot_path.write_text(yaml.safe_dump(cfg.model_dump(mode="json"), sort_keys=False))
|
||||
with CONSOLE.status("Running cdk destroy..."):
|
||||
provisioning.destroy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
config_path=str(snapshot_path),
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
|
||||
|
||||
|
||||
def _destroy_config(config_path: str) -> Config:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
config_snapshot = state.get("config")
|
||||
if config_snapshot:
|
||||
return Config.model_validate(config_snapshot)
|
||||
return load_cfg(config_path)
|
||||
|
||||
|
||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
account_id = state.get("aws", {}).get("account_id")
|
||||
if account_id:
|
||||
return str(account_id)
|
||||
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
29
src/commands/utils.py
Normal file
29
src/commands/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
|
||||
from src.config import Config
|
||||
|
||||
CONSOLE = Console()
|
||||
CONFIG_OPT = typer.Option("config.yaml", "--config", "-c", help="Path to config file")
|
||||
|
||||
|
||||
def load_config(path: str = "config.yaml") -> Config:
|
||||
config_path = Path(path)
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Config file not found: {config_path}. Run 'qai-cli init' to create one."
|
||||
)
|
||||
with open(config_path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
return Config.model_validate(data)
|
||||
|
||||
|
||||
def load_cfg(path: str = "config.yaml") -> Config:
|
||||
try:
|
||||
return load_config(path)
|
||||
except FileNotFoundError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
66
src/config.py
Normal file
66
src/config.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MlflowMode(str, Enum):
|
||||
disabled = "disabled"
|
||||
create = "create"
|
||||
existing = "existing"
|
||||
|
||||
|
||||
class MlflowServerSize(str, Enum):
|
||||
small = "Small"
|
||||
medium = "Medium"
|
||||
large = "Large"
|
||||
|
||||
|
||||
class AwsConfig(BaseModel):
|
||||
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
||||
profile: str = "default"
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-onnx-bucket"
|
||||
data_prefix: str = "data/"
|
||||
model_prefix: str = "models/"
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
instance_type: TrainingInstanceTypeType = "ml.m5.xlarge"
|
||||
instance_count: int = 1
|
||||
image_uri: str = ""
|
||||
entry_point: str | None = None
|
||||
source_dir: str | None = None
|
||||
hyperparameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
|
||||
role_name: str = "qai-cli-sagemaker-role"
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
mode: MlflowMode = MlflowMode.disabled
|
||||
tracking_server_name: str | None = None
|
||||
artifact_prefix: str = "mlflow/"
|
||||
tracking_server_size: MlflowServerSize = MlflowServerSize.small
|
||||
mlflow_version: str | None = None
|
||||
automatic_model_registration: bool = False
|
||||
weekly_maintenance_window_start: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_tracking_server_name(self) -> "MlflowConfig":
|
||||
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
|
||||
return self
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
aws: AwsConfig = Field(default_factory=AwsConfig)
|
||||
s3: S3Config = Field(default_factory=S3Config)
|
||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
||||
0
src/infra/__init__.py
Normal file
0
src/infra/__init__.py
Normal file
118
src/infra/provisioning.py
Normal file
118
src/infra/provisioning.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.infra.state import state_path, write_infra_state
|
||||
|
||||
STACK_NAME = "QaiCliStack"
|
||||
|
||||
|
||||
def bootstrap(*, profile: str, account_id: str, region: str) -> None:
|
||||
_run(["cdk", "bootstrap", f"aws://{account_id}/{region}", "--profile", profile])
|
||||
|
||||
|
||||
def deploy(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
config_path: str,
|
||||
config_dir: str,
|
||||
config_snapshot: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
outputs_file = state_path(config_dir).with_suffix(".cdk-outputs.json")
|
||||
cmd = _cdk_cmd(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=False,
|
||||
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
|
||||
_run(cmd)
|
||||
|
||||
outputs = _read_outputs(outputs_file)
|
||||
state = {
|
||||
"stack_name": STACK_NAME,
|
||||
"aws": {
|
||||
"account_id": account_id,
|
||||
"region": region,
|
||||
"profile": profile,
|
||||
},
|
||||
"config": config_snapshot,
|
||||
"outputs": outputs,
|
||||
}
|
||||
write_infra_state(config_dir, state)
|
||||
return state
|
||||
|
||||
|
||||
def destroy(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> None:
|
||||
if delete_bucket_data:
|
||||
update_cmd = _cdk_cmd(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=True,
|
||||
) + ["--require-approval", "never"]
|
||||
_run(update_cmd)
|
||||
|
||||
cmd = _cdk_cmd(
|
||||
"destroy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
) + ["--force"]
|
||||
_run(cmd)
|
||||
|
||||
|
||||
def _cdk_cmd(
|
||||
action: str,
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> list[str]:
|
||||
cmd = [
|
||||
"cdk",
|
||||
action,
|
||||
STACK_NAME,
|
||||
"--app",
|
||||
"python app.py",
|
||||
"--profile",
|
||||
profile,
|
||||
"-c",
|
||||
f"account_id={account_id}",
|
||||
"-c",
|
||||
f"config={config_path}",
|
||||
"-c",
|
||||
f"stack_name={STACK_NAME}",
|
||||
"-c",
|
||||
f"delete_bucket_data={str(delete_bucket_data).lower()}",
|
||||
]
|
||||
return cmd
|
||||
|
||||
|
||||
def _run(cmd: list[str]) -> None:
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
except FileNotFoundError as e:
|
||||
raise RuntimeError("CDK CLI not found. Install it with: npm install -g aws-cdk") from e
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
|
||||
|
||||
|
||||
def _read_outputs(path: Path) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
return data.get(STACK_NAME, {})
|
||||
169
src/infra/stack.py
Normal file
169
src/infra/stack.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from typing import Any
|
||||
|
||||
from aws_cdk import CfnOutput, RemovalPolicy, Stack
|
||||
from aws_cdk import aws_iam as iam
|
||||
from aws_cdk import aws_s3 as s3
|
||||
from aws_cdk import aws_sagemaker as sagemaker
|
||||
from constructs import Construct
|
||||
|
||||
from src.config import Config, MlflowMode
|
||||
|
||||
|
||||
class QaiStack(Stack):
|
||||
def __init__(
|
||||
self,
|
||||
scope: Construct,
|
||||
construct_id: str,
|
||||
*,
|
||||
config: Config,
|
||||
delete_bucket_data: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(scope, construct_id, **kwargs)
|
||||
|
||||
removal_policy = RemovalPolicy.DESTROY if delete_bucket_data else RemovalPolicy.RETAIN
|
||||
data_bucket = s3.Bucket(
|
||||
self,
|
||||
"DataBucket",
|
||||
bucket_name=config.s3.bucket,
|
||||
versioned=True,
|
||||
removal_policy=removal_policy,
|
||||
auto_delete_objects=delete_bucket_data,
|
||||
)
|
||||
|
||||
role = iam.CfnRole(
|
||||
self,
|
||||
"SageMakerRole",
|
||||
role_name=config.sagemaker.role_name,
|
||||
assume_role_policy_document=self._sagemaker_trust_policy(),
|
||||
managed_policy_arns=[
|
||||
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
|
||||
],
|
||||
)
|
||||
iam.CfnPolicy(
|
||||
self,
|
||||
"SageMakerRoleS3Policy",
|
||||
roles=[role.ref],
|
||||
policy_name="SageMakerRoleS3Policy",
|
||||
policy_document={
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject*",
|
||||
"s3:GetBucket*",
|
||||
"s3:List*",
|
||||
"s3:DeleteObject*",
|
||||
"s3:PutObject",
|
||||
"s3:PutObjectLegalHold",
|
||||
"s3:PutObjectRetention",
|
||||
"s3:PutObjectTagging",
|
||||
"s3:PutObjectVersionTagging",
|
||||
"s3:Abort*",
|
||||
],
|
||||
"Resource": [
|
||||
data_bucket.bucket_arn,
|
||||
f"{data_bucket.bucket_arn}/*",
|
||||
],
|
||||
}],
|
||||
},
|
||||
)
|
||||
|
||||
CfnOutput(self, "DataBucketName", value=data_bucket.bucket_name)
|
||||
CfnOutput(self, "DataBucketArn", value=data_bucket.bucket_arn)
|
||||
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
|
||||
|
||||
if config.mlflow.mode is MlflowMode.create:
|
||||
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
|
||||
artifact_uri = (
|
||||
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
|
||||
if artifact_prefix
|
||||
else f"s3://{data_bucket.bucket_name}/"
|
||||
)
|
||||
mlflow_role = iam.CfnRole(
|
||||
self,
|
||||
"MlflowRole",
|
||||
assume_role_policy_document=self._sagemaker_trust_policy(),
|
||||
)
|
||||
s3_statement: dict[str, Any] = {
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject*",
|
||||
"s3:GetBucket*",
|
||||
"s3:List*",
|
||||
"s3:DeleteObject*",
|
||||
"s3:PutObject",
|
||||
"s3:PutObjectLegalHold",
|
||||
"s3:PutObjectRetention",
|
||||
"s3:PutObjectTagging",
|
||||
"s3:PutObjectVersionTagging",
|
||||
"s3:Abort*",
|
||||
],
|
||||
"Resource": [
|
||||
data_bucket.bucket_arn,
|
||||
(
|
||||
f"{data_bucket.bucket_arn}/{artifact_prefix}/*"
|
||||
if artifact_prefix
|
||||
else f"{data_bucket.bucket_arn}/*"
|
||||
),
|
||||
],
|
||||
}
|
||||
list_statement: dict[str, Any] = {
|
||||
"Effect": "Allow",
|
||||
"Action": "s3:ListBucket",
|
||||
"Resource": data_bucket.bucket_arn,
|
||||
}
|
||||
if artifact_prefix:
|
||||
list_statement["Condition"] = {"StringLike": {"s3:prefix": [f"{artifact_prefix}/*"]}}
|
||||
iam.CfnPolicy(
|
||||
self,
|
||||
"MlflowRolePolicy",
|
||||
roles=[mlflow_role.ref],
|
||||
policy_name="MlflowRolePolicy",
|
||||
policy_document={
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
s3_statement,
|
||||
list_statement,
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"sagemaker:AddTags",
|
||||
"sagemaker:CreateModelPackageGroup",
|
||||
"sagemaker:CreateModelPackage",
|
||||
"sagemaker:UpdateModelPackage",
|
||||
"sagemaker:DescribeModelPackageGroup",
|
||||
],
|
||||
"Resource": "*",
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
tracking_server = sagemaker.CfnMlflowTrackingServer(
|
||||
self,
|
||||
"MlflowTrackingServer",
|
||||
artifact_store_uri=artifact_uri,
|
||||
role_arn=mlflow_role.attr_arn,
|
||||
tracking_server_name=config.mlflow.tracking_server_name or "",
|
||||
automatic_model_registration=config.mlflow.automatic_model_registration,
|
||||
mlflow_version=config.mlflow.mlflow_version,
|
||||
tracking_server_size=config.mlflow.tracking_server_size.value,
|
||||
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
|
||||
)
|
||||
|
||||
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
|
||||
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
|
||||
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
|
||||
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)
|
||||
|
||||
@staticmethod
|
||||
def _sagemaker_trust_policy() -> dict[str, Any]:
|
||||
return {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "sagemaker.amazonaws.com"},
|
||||
"Action": "sts:AssumeRole",
|
||||
}],
|
||||
}
|
||||
22
src/infra/state.py
Normal file
22
src/infra/state.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
INFRA_STATE_FILE = ".qai-cli-infra.json"
|
||||
|
||||
|
||||
def state_path(config_dir: str) -> Path:
|
||||
return Path(config_dir) / INFRA_STATE_FILE
|
||||
|
||||
|
||||
def read_infra_state(config_dir: str) -> dict[str, Any]:
|
||||
path = state_path(config_dir)
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_infra_state(config_dir: str, state: dict[str, Any]) -> None:
|
||||
with open(state_path(config_dir), "w") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
39
src/main.py
Normal file
39
src/main.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
|
||||
from src.commands import infra
|
||||
from src.config import Config
|
||||
|
||||
app = typer.Typer(
|
||||
help="qai-cli: End-to-end model managment for Qualcomm AI Hub.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(infra.app, name="infra")
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
|
||||
def init(
|
||||
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
||||
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
||||
) -> None:
|
||||
"""Write a starter config.yaml to the current directory."""
|
||||
dest = Path(output)
|
||||
if dest.exists() and not force:
|
||||
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
config = Config()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(dest, "w") as f:
|
||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||
|
||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
console.print(
|
||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||
"before running other commands."
|
||||
)
|
||||
Reference in New Issue
Block a user