make sure resources are set up in isolated namespaces (#1)
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -218,7 +218,7 @@ __marimo__/
|
||||
.streamlit/secrets.toml
|
||||
|
||||
.venv/
|
||||
config.yaml
|
||||
config*.yaml
|
||||
cdk.out/
|
||||
.qc-cli*.json
|
||||
examples/*/data/
|
||||
|
||||
18
README.md
18
README.md
@@ -30,7 +30,7 @@ qc-cli --help
|
||||
# 1. Create config.yaml in the current directory
|
||||
qc-cli init
|
||||
|
||||
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.training.image_uri
|
||||
# 2. Edit config.yaml — at minimum set sagemaker.training.image_uri
|
||||
|
||||
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
||||
# This is the step that requires the AWS CDK CLI.
|
||||
@@ -47,15 +47,17 @@ qc-cli train status
|
||||
`qc-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
|
||||
|
||||
```yaml
|
||||
infra:
|
||||
stack_name: qc-cli-mlops-1a2b3c4d5e6f
|
||||
|
||||
aws:
|
||||
region: us-east-1
|
||||
profile: default # AWS CLI profile name
|
||||
|
||||
s3:
|
||||
bucket: your-unique-bucket-name
|
||||
bucket: qc-cli-mlops-1a2b3c4d5e6f-data
|
||||
|
||||
sagemaker:
|
||||
role_name: qc-cli-sagemaker-role
|
||||
training:
|
||||
image_uri: "" # ECR URI for your training container
|
||||
instance_type: ml.m5.xlarge
|
||||
@@ -65,6 +67,10 @@ sagemaker:
|
||||
hyperparameters: {}
|
||||
```
|
||||
|
||||
`qc-cli init` generates a unique namespace once and writes it to `config.yaml`: `infra.stack_name` is the namespace itself, and `s3.bucket` is the same namespace with a `-data` suffix. Keep these values stable for a deployment; changing them points the CLI at different infrastructure.
|
||||
|
||||
The CLI isolates both application resources and CDK bootstrap resources. The application CloudFormation stack uses `infra.stack_name`, the S3 bucket uses the same generated namespace because bucket names are globally unique, and the SageMaker IAM role uses a CloudFormation-generated physical name. CDK bootstrap resources are derived internally from `infra.stack_name`, including a bootstrap stack named `<stack_name>-bootstrap` and a matching non-default CDK asset bucket qualifier. `qc-cli infra destroy` removes the application stack but leaves the CDK bootstrap stack in place; the command prints the retained bootstrap stack name.
|
||||
|
||||
`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point.
|
||||
|
||||
To provision an MLflow tracking server, set:
|
||||
@@ -105,6 +111,12 @@ qc-cli infra destroy --yes Destroy stack without confirmation
|
||||
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
||||
```
|
||||
|
||||
`--cloudformation-execution-policy` is a one-time CDK bootstrap option, not a `config.yaml` setting. Pass it on `infra setup` when you need the CDK bootstrap CloudFormation execution role to use a policy other than the default `AdministratorAccess`:
|
||||
|
||||
```bash
|
||||
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
|
||||
```
|
||||
|
||||
### `upload`
|
||||
|
||||
```
|
||||
|
||||
4
app.py
4
app.py
@@ -8,17 +8,19 @@ from src.infra.stack import QCStack
|
||||
app = cdk.App()
|
||||
|
||||
config_path = app.node.try_get_context("config") or "config.yaml"
|
||||
stack_name = app.node.try_get_context("stack_name") or "MLOpsStack"
|
||||
account_id = app.node.try_get_context("account_id") or os.getenv("CDK_DEFAULT_ACCOUNT")
|
||||
delete_bucket_data = str(app.node.try_get_context("delete_bucket_data") or "false").lower() == "true"
|
||||
|
||||
cfg = load_config(config_path)
|
||||
stack_name = app.node.try_get_context("stack_name") or cfg.infra.stack_name
|
||||
bootstrap_qualifier = app.node.try_get_context("bootstrap_qualifier") or cfg.infra.effective_bootstrap_qualifier
|
||||
|
||||
QCStack(
|
||||
app,
|
||||
stack_name,
|
||||
config=cfg,
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
synthesizer=cdk.DefaultStackSynthesizer(qualifier=bootstrap_qualifier),
|
||||
env=cdk.Environment(
|
||||
account=account_id,
|
||||
region=cfg.aws.region,
|
||||
|
||||
@@ -13,7 +13,6 @@ s3:
|
||||
bucket: your-bucket-name
|
||||
|
||||
sagemaker:
|
||||
role_name: <role-name>
|
||||
training:
|
||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||
instance_type: ml.m4.xlarge
|
||||
|
||||
@@ -3,13 +3,11 @@ from typing import Any
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from src.infra.provisioning import STACK_NAME
|
||||
|
||||
|
||||
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
|
||||
def stack_status(region: str, profile: str, stack_name: str) -> dict[str, Any] | None:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
|
||||
try:
|
||||
stack = client.describe_stacks(StackName=STACK_NAME)["Stacks"][0]
|
||||
stack = client.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
except ClientError as e:
|
||||
message = e.response.get("Error", {}).get("Message", "")
|
||||
if "does not exist" in message:
|
||||
|
||||
@@ -51,6 +51,8 @@ def setup(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||
cloudformation_execution_policy=cloudformation_execution_policy,
|
||||
)
|
||||
with CONSOLE.status("Running cdk deploy..."):
|
||||
@@ -58,6 +60,9 @@ def setup(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
stack_name=cfg.infra.stack_name,
|
||||
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||
config_path=config,
|
||||
config_dir=str(Path(config).parent),
|
||||
config_snapshot=cfg.model_dump(mode="json"),
|
||||
@@ -82,7 +87,7 @@ def setup(
|
||||
def status(config: str = CONFIG_OPT) -> None:
|
||||
"""Show current infrastructure status."""
|
||||
cfg = load_cfg(config)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)
|
||||
|
||||
table = Table(title="Infrastructure Status")
|
||||
table.add_column("Resource", style="cyan")
|
||||
@@ -91,7 +96,7 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
table.add_column("ARN / URI")
|
||||
|
||||
if not stack:
|
||||
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
|
||||
table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
|
||||
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
||||
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
@@ -114,7 +119,7 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
table.add_row(
|
||||
"IAM Role",
|
||||
cfg.sagemaker.role_name,
|
||||
_role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
|
||||
"[green]managed[/green]",
|
||||
outputs.get("SageMakerRoleArn", "-"),
|
||||
)
|
||||
@@ -156,10 +161,13 @@ def destroy(
|
||||
) -> None:
|
||||
"""Destroy the CDK stack."""
|
||||
cfg = _destroy_config(config)
|
||||
stack_name = _destroy_stack_name(config, cfg)
|
||||
bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
|
||||
toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)
|
||||
|
||||
if not yes and not delete_bucket_data:
|
||||
typer.confirm(
|
||||
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
|
||||
f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
@@ -172,13 +180,17 @@ def destroy(
|
||||
provisioning.destroy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=str(snapshot_path),
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
|
||||
CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
|
||||
|
||||
|
||||
def _destroy_config(config_path: str) -> Config:
|
||||
@@ -190,6 +202,13 @@ def _destroy_config(config_path: str) -> Config:
|
||||
return load_cfg(config_path)
|
||||
|
||||
|
||||
def _role_name(configured_name: str, role_arn: str) -> str:
|
||||
if configured_name:
|
||||
return configured_name
|
||||
if role_arn:
|
||||
return role_arn.rsplit("/", 1)[-1]
|
||||
return "-"
|
||||
|
||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
@@ -197,3 +216,30 @@ def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
if account_id:
|
||||
return str(account_id)
|
||||
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
|
||||
|
||||
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
    """Resolve the application stack name that ``destroy`` should target.

    The value stored under ``"stack_name"`` in the infra state located next
    to ``config_path`` takes precedence; when it is missing or falsy the
    name from ``cfg.infra.stack_name`` is used.

    Args:
        config_path: Path to the config file; its parent directory is where
            the infra state is looked up.
        cfg: Loaded configuration used as the fallback source.

    Returns:
        The stack name as a string.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("stack_name")
    return str(recorded) if recorded else cfg.infra.stack_name
|
||||
|
||||
|
||||
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
    """Resolve the CDK bootstrap qualifier that ``destroy`` should use.

    The value stored under ``"bootstrap_qualifier"`` in the infra state
    located next to ``config_path`` takes precedence; when it is missing or
    falsy the qualifier derived from the config
    (``cfg.infra.effective_bootstrap_qualifier``) is used.

    Args:
        config_path: Path to the config file; its parent directory is where
            the infra state is looked up.
        cfg: Loaded configuration used as the fallback source.

    Returns:
        The bootstrap qualifier as a string.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("bootstrap_qualifier")
    return str(recorded) if recorded else cfg.infra.effective_bootstrap_qualifier
|
||||
|
||||
|
||||
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
    """Resolve the CDK bootstrap (toolkit) stack name used by ``destroy``.

    The value stored under ``"toolkit_stack_name"`` in the infra state
    located next to ``config_path`` takes precedence; when it is missing or
    falsy the name derived from the config
    (``cfg.infra.effective_toolkit_stack_name``) is used.

    Args:
        config_path: Path to the config file; its parent directory is where
            the infra state is looked up.
        cfg: Loaded configuration used as the fallback source.

    Returns:
        The toolkit stack name as a string.
    """
    recorded = read_infra_state(str(Path(config_path).parent)).get("toolkit_stack_name")
    return str(recorded) if recorded else cfg.infra.effective_toolkit_stack_name
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.table import Table
|
||||
@@ -7,6 +8,8 @@ from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config
|
||||
from src.infra.state import read_infra_state
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
@@ -20,10 +23,22 @@ _STATUS_COLOR = {
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
from pathlib import Path
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
|
||||
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
    """Find the ARN of the SageMaker execution role for a training job.

    Lookup order:
      1. ``outputs["SageMakerRoleArn"]`` from the infra state stored next to
         ``config_path``.
      2. If the config names a role (``cfg.sagemaker.role_name``), an IAM
         lookup of that name via ``iam.get_role_arn``.

    Args:
        config_path: Path to the config file; the infra state is read from
            its parent directory.
        cfg: Loaded configuration (AWS profile and optional role name).

    Returns:
        The role ARN.

    Raises:
        RuntimeError: If no ARN is recorded in the infra state and either no
            role name is configured or the configured role cannot be found.
    """
    outputs = read_infra_state(_config_dir(config_path)).get("outputs", {})
    recorded_arn = outputs.get("SageMakerRoleArn")
    if recorded_arn:
        return str(recorded_arn)

    # Nothing recorded and no explicit name to look up: setup has not run.
    if not cfg.sagemaker.role_name:
        raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")

    looked_up = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
    if not looked_up:
        raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")
    return looked_up
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
@@ -37,9 +52,10 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if not role_arn:
|
||||
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||
try:
|
||||
role_arn = _sagemaker_role_arn(config, cfg)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
@@ -32,6 +33,33 @@ class AwsConfig(BaseModel):
|
||||
return {"profile_name": self.profile, "region_name": self.region}
|
||||
|
||||
|
||||
DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
|
||||
GENERATED_STACK_PREFIX = "qc-cli-mlops-"
|
||||
|
||||
|
||||
class InfraConfig(BaseModel):
    """Naming of the application stack and its CDK bootstrap resources.

    Both bootstrap identifiers are derived from ``stack_name``, so each
    stack name maps to its own bootstrap stack and qualifier.
    """

    # CloudFormation stack name for the application stack. Required: the
    # field has no default.
    stack_name: str

    @property
    def effective_bootstrap_qualifier(self) -> str:
        """Return the CDK bootstrap qualifier derived from ``stack_name``.

        The name is lower-cased and stripped to ``[a-z0-9]``. For names that
        start with ``GENERATED_STACK_PREFIX`` only the part after the prefix
        is used (unless that part sanitises to nothing). The result is
        prefixed with ``"q"`` and truncated to 10 characters.
        NOTE(review): 10 is presumably the CDK qualifier length limit —
        confirm against the CDK bootstrapping documentation.

        Returns:
            The derived qualifier, or ``DEFAULT_BOOTSTRAP_QUALIFIER`` when
            ``stack_name`` contains no alphanumeric characters.
        """
        pattern = r"[^a-z0-9]"
        sanitized = re.sub(pattern, "", self.stack_name.lower())
        if not sanitized:
            return DEFAULT_BOOTSTRAP_QUALIFIER
        # removeprefix() is a no-op for names without the generated prefix,
        # in which case `suffix` equals `sanitized`; an empty suffix (name is
        # exactly the prefix) falls back to the whole sanitised name.
        suffix = re.sub(pattern, "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX).lower())
        return f"q{suffix or sanitized}"[:10]

    @property
    def effective_toolkit_stack_name(self) -> str:
        """Return the CDK bootstrap stack name: ``<stack_name>-bootstrap``."""
        return f"{self.stack_name}-bootstrap"
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-qc-mlops-bucket"
|
||||
data_prefix: str = "data/"
|
||||
@@ -48,7 +76,7 @@ class TrainingConfig(BaseModel):
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
    """SageMaker settings: the execution role name and training options."""

    # Optional explicit IAM role name. Defaults to the empty string, meaning
    # no fixed name is configured.
    role_name: str = ""
    # Training job settings; a fresh TrainingConfig is built per instance.
    training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
@@ -69,6 +97,7 @@ class MlflowConfig(BaseModel):
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
infra: InfraConfig
|
||||
aws: AwsConfig = Field(default_factory=AwsConfig)
|
||||
s3: S3Config = Field(default_factory=S3Config)
|
||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||
|
||||
@@ -5,17 +5,27 @@ from typing import Any
|
||||
|
||||
from src.infra.state import state_path, write_infra_state
|
||||
|
||||
STACK_NAME = "MLOpsStack"
|
||||
|
||||
|
||||
def bootstrap(
    *,
    profile: str,
    account_id: str,
    region: str,
    bootstrap_qualifier: str,
    toolkit_stack_name: str,
    cloudformation_execution_policy: str | None = None,
) -> None:
    """Run ``cdk bootstrap`` for one account/region with isolated resources.

    Args:
        profile: AWS CLI profile passed as ``--profile``.
        account_id: AWS account id used in the ``aws://<account>/<region>``
            environment argument.
        region: AWS region used in the environment argument.
        bootstrap_qualifier: Value passed as ``--qualifier``.
        toolkit_stack_name: Value passed as ``--toolkit-stack-name``.
        cloudformation_execution_policy: Optional policy ARN passed as
            ``--cloudformation-execution-policies``. When omitted or empty,
            the flag is not passed at all.

    The assembled command is handed to ``_run``; any exception it raises
    propagates to the caller.
    """
    cmd = [
        "cdk",
        "bootstrap",
        f"aws://{account_id}/{region}",
        "--profile",
        profile,
        "--qualifier",
        bootstrap_qualifier,
        "--toolkit-stack-name",
        toolkit_stack_name,
    ]
    if cloudformation_execution_policy:
        cmd.extend(["--cloudformation-execution-policies", cloudformation_execution_policy])
    _run(cmd)
|
||||
@@ -26,6 +36,9 @@ def deploy(
|
||||
profile: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
config_dir: str,
|
||||
config_snapshot: dict[str, Any],
|
||||
@@ -35,19 +48,24 @@ def deploy(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=False,
|
||||
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
|
||||
_run(cmd)
|
||||
|
||||
outputs = _read_outputs(outputs_file)
|
||||
outputs = _read_outputs(outputs_file, stack_name)
|
||||
state = {
|
||||
"stack_name": STACK_NAME,
|
||||
"stack_name": stack_name,
|
||||
"aws": {
|
||||
"account_id": account_id,
|
||||
"region": region,
|
||||
"profile": profile,
|
||||
},
|
||||
"bootstrap_qualifier": bootstrap_qualifier,
|
||||
"toolkit_stack_name": toolkit_stack_name,
|
||||
"config": config_snapshot,
|
||||
"outputs": outputs,
|
||||
}
|
||||
@@ -59,6 +77,9 @@ def destroy(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> None:
|
||||
@@ -67,6 +88,9 @@ def destroy(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=True,
|
||||
) + ["--require-approval", "never"]
|
||||
@@ -76,6 +100,9 @@ def destroy(
|
||||
"destroy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
) + ["--force"]
|
||||
@@ -87,26 +114,35 @@ def _cdk_cmd(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> list[str]:
|
||||
cmd = [
|
||||
"cdk",
|
||||
action,
|
||||
STACK_NAME,
|
||||
stack_name,
|
||||
"--app",
|
||||
"python app.py",
|
||||
"--profile",
|
||||
profile,
|
||||
]
|
||||
if action == "deploy":
|
||||
cmd.extend(["--toolkit-stack-name", toolkit_stack_name])
|
||||
cmd.extend([
|
||||
"-c",
|
||||
f"account_id={account_id}",
|
||||
"-c",
|
||||
f"config={config_path}",
|
||||
"-c",
|
||||
f"stack_name={STACK_NAME}",
|
||||
f"stack_name={stack_name}",
|
||||
"-c",
|
||||
f"bootstrap_qualifier={bootstrap_qualifier}",
|
||||
"-c",
|
||||
f"delete_bucket_data={str(delete_bucket_data).lower()}",
|
||||
]
|
||||
])
|
||||
return cmd
|
||||
|
||||
|
||||
@@ -119,9 +155,9 @@ def _run(cmd: list[str]) -> None:
|
||||
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
|
||||
|
||||
|
||||
def _read_outputs(path: Path) -> dict[str, str]:
|
||||
def _read_outputs(path: Path, stack_name: str) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
return data.get(STACK_NAME, {})
|
||||
return data.get(stack_name, {})
|
||||
|
||||
@@ -34,7 +34,7 @@ class QCStack(Stack):
|
||||
role = iam.CfnRole(
|
||||
self,
|
||||
"SageMakerRole",
|
||||
role_name=config.sagemaker.role_name,
|
||||
role_name=config.sagemaker.role_name or None,
|
||||
assume_role_policy_document=self._sagemaker_trust_policy(),
|
||||
managed_policy_arns=[
|
||||
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
|
||||
|
||||
20
src/main.py
20
src/main.py
@@ -1,3 +1,4 @@
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
@@ -8,7 +9,7 @@ from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn
|
||||
from src.aws import s3 as s3_ops
|
||||
from src.commands import infra, train
|
||||
from src.commands.utils import CONFIG_OPT, load_cfg
|
||||
from src.config import Config
|
||||
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
|
||||
|
||||
app = typer.Typer(
|
||||
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
|
||||
@@ -31,18 +32,27 @@ def init(
|
||||
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
config = Config()
|
||||
config = _new_isolated_config()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_data = config.model_dump(mode="json")
|
||||
config_data["sagemaker"].pop("role_name", None)
|
||||
with open(dest, "w") as f:
|
||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||
yaml.safe_dump(config_data, f, sort_keys=False)
|
||||
|
||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
console.print(
|
||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||
"before running other commands."
|
||||
"Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands."
|
||||
)
|
||||
|
||||
|
||||
def _new_isolated_config() -> Config:
    """Build a default config whose stack and bucket share a fresh namespace.

    The namespace is ``GENERATED_STACK_PREFIX`` followed by 12 random hex
    characters (``secrets.token_hex(6)``). It becomes ``infra.stack_name``,
    and ``s3.bucket`` is set to ``<namespace>-data``.

    Returns:
        A new ``Config`` with every other setting left at its default.
    """
    namespace = GENERATED_STACK_PREFIX + secrets.token_hex(6)
    isolated = Config(infra=InfraConfig(stack_name=namespace))
    isolated.s3 = S3Config(bucket=f"{namespace}-data")
    return isolated
|
||||
|
||||
|
||||
@app.command()
|
||||
def upload(
|
||||
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
||||
|
||||
Reference in New Issue
Block a user