17 Commits

Author SHA1 Message Date
b56b77330c clean 2026-06-12 14:34:45 -04:00
a1ffbb77c5 include-metrics-from-training (#6)
Reviewed-on: #6
2026-06-12 18:23:25 +00:00
522ddc74e2 New example and updated ai-hun upload order (#4)
Co-authored-by: samirodr <sami.rodrigue@slalom.com>
Reviewed-on: #4
2026-06-12 14:34:44 +00:00
samirodr
5360a482fc update 2026-06-08 14:59:44 -04:00
samirodr
6a560a8610 match 2026-06-08 14:54:13 -04:00
d244150d98 move mlflow to its own command 2026-06-05 11:47:38 -04:00
d7c7158464 clean main file 2026-06-05 11:25:04 -04:00
6bc25dc183 restructure config to use Device class directly
Also include device validation
2026-06-04 17:28:17 -04:00
samirodr
71a95aa3a7 update description 2026-06-03 17:13:00 -04:00
a3f3060e13 ai-hub (#3)
Reviewed-on: #3
2026-06-03 21:06:06 +00:00
e9ada2612f Mlflow implementation (#2)
Reviewed-on: #2
2026-06-02 19:04:23 +00:00
6ac9702dc5 make sure resources are set up in isolated namespaces (#1)
Reviewed-on: #1
2026-05-27 12:51:26 +00:00
0e728cc193 command to start sagemaker training
include sample training
2026-05-25 16:48:31 -04:00
62ffe163e8 enable s3 upload 2026-05-20 16:42:07 -04:00
samirodr
cfc04b473f update 2026-05-20 15:21:45 -04:00
717257dd75 update naming 2026-05-20 14:16:12 -04:00
75255b37d0 rename 2026-05-20 14:13:21 -04:00
38 changed files with 5348 additions and 102 deletions

223
.gitignore vendored
View File

@@ -1,5 +1,224 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Temporary file for partial code execution
tempCodeRunnerFile.py
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
.venv/
config.yaml
config*.yaml
cdk.out/
.qai-cli-infra*
.qc-cli*.json
examples/*/data/

220
README.md
View File

@@ -1,66 +1,102 @@
# qai-cli
# qc-cli
A CLI for the Qualcomm model MLOps pipeline — browse and download models from Qualcomm AI Hub, fine-tune them on custom datasets using SageMaker, validate inference, and prepare artifacts for Qualcomm hardware deployment.
A CLI for Qualcomm's MLOps pipeline — browse and download models from Qualcomm AI Hub, fine-tune them on custom datasets using SageMaker, validate inference, and prepare artifacts for Qualcomm hardware deployment.
## Requirements
- Python 3.13+
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
- AWS account with credentials configured (`aws configure`) when using `qai-cli infra`
- AWS CDK CLI (`npm install -g aws-cdk`) when using `qai-cli infra setup` or `qai-cli infra destroy`
- AWS account with credentials configured (`aws configure`) when using `qc-cli infra`
- AWS CDK CLI (`npm install -g aws-cdk`) when using `qc-cli infra setup` or `qc-cli infra destroy`
## Installation
```bash
git clone <repo>
cd qai-cli
cd qc-cli
uv sync
```
Run commands with `uv run qai-cli <command>` or activate the venv first:
Run commands with `uv run qc-cli <command>` or activate the venv first:
```bash
source .venv/bin/activate
qai-cli --help
qc-cli --help
```
## Quick start
```bash
# 1. Create config.yaml in the current directory
qai-cli init
qc-cli init
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.role_name
# 2. Edit config.yaml — at minimum set sagemaker.training.image_uri
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
# This is the step that requires the AWS CDK CLI.
qai-cli infra setup
qc-cli infra setup
# 4. Upload training data, then submit a SageMaker training job.
qc-cli upload ./my-dataset
qc-cli train start
qc-cli train status
```
## Configuration
`qai-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
`qc-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
```yaml
infra:
stack_name: qc-cli-mlops-1a2b3c4d5e6f
aws:
region: us-east-1
profile: default # AWS CLI profile name
s3:
bucket: your-unique-bucket-name
bucket: qc-cli-mlops-1a2b3c4d5e6f-data
sagemaker:
role_name: qai-cli-sagemaker-role
training:
image_uri: "" # ECR URI for your training container
instance_type: ml.m5.xlarge
instance_count: 1
entry_point: null # Optional: script inside source_dir
source_dir: null # Optional: local dir packaged and uploaded automatically
hyperparameters: {}
aihub:
device:
name: Samsung Galaxy S25 (Family)
target_runtime: tflite
input_specs: {} # Required before running qc-cli ai-hub commands
job_name: null # Optional prefix for AI Hub Workbench jobs
model_name: null # Optional name for uploaded local ONNX models
compile_options: null
profile_options: null
quantize_options: null
output_dir: build/qai-hub
```
`qc-cli init` generates the `infra.stack_name` and `s3.bucket` namespace once and writes it to `config.yaml`. Keep these values stable for a deployment; changing them points the CLI at different infrastructure.
The CLI isolates both application resources and CDK bootstrap resources. The application CloudFormation stack uses `infra.stack_name`, the S3 bucket uses the same generated namespace because bucket names are globally unique, and the SageMaker IAM role uses a CloudFormation-generated physical name. CDK bootstrap resources are derived internally from `infra.stack_name`, including a bootstrap stack named `<stack_name>-bootstrap` and a matching non-default CDK asset bucket qualifier. `qc-cli infra destroy` removes the application stack but leaves the CDK bootstrap stack in place; the command prints the retained bootstrap stack name.
`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point.
To provision an MLflow tracking server, set:
```yaml
mlflow:
mode: create
tracking_server_name: your-tracking-server-name
experiment_name: qc-cli-training
registered_model_name: qc-cli-model
register_trained_models: true
```
In `create` mode, the CLI manages the tracking server name from `infra.stack_name`; you do not need to set `tracking_server_name`.
To use an existing MLflow tracking server, set:
```yaml
@@ -69,28 +105,163 @@ mlflow:
tracking_server_name: your-tracking-server-name
```
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Training metrics can be upload with `train start --upload-metrics` or `mlflow upload-metrics`.
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
```bash
qc-cli mlflow open --config config.yaml
```
This opens a browser to a fresh presigned URL. It works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
## Commands
### `init`
```
qai-cli init Write config.yaml
qai-cli init --output <path> Write config to a custom path
qai-cli init --force Overwrite an existing config file
qc-cli init Write config.yaml
qc-cli init --output <path> Write config to a custom path
qc-cli init --force Overwrite an existing config file
```
### `infra`
```
qai-cli infra setup Deploy the CDK stack
qai-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
qai-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
qai-cli infra status Show CDK stack/resource status
qai-cli infra destroy Destroy stack, retaining S3 data
qai-cli infra destroy --yes Destroy stack without confirmation
qai-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
qc-cli infra setup Deploy the CDK stack
qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
qc-cli infra status Show CDK stack/resource status
qc-cli infra destroy Destroy stack, retaining S3 data
qc-cli infra destroy --yes Destroy stack without confirmation
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
```
`--cloudformation-execution-policy` is a one-time CDK bootstrap option, not a `config.yaml` setting. Pass it on `infra setup` when you need the CDK bootstrap CloudFormation execution role to use a policy other than the default `AdministratorAccess`:
```bash
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
```
### `mlflow`
```
qc-cli mlflow open Open a presigned MLflow UI URL
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
```
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. Use `--force` to upload the metrics again.
### `upload`
```
qc-cli upload <file> Upload a single file to S3
qc-cli upload <dir> Upload all files in a directory tree to S3
qc-cli upload <file> --s3-key <key> Upload a file to a custom S3 key
```
Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads default to `s3://<bucket>/<data_prefix>/<filename>`. Directory uploads are recursive, preserve paths relative to the uploaded directory, and place files under `s3://<bucket>/<data_prefix>/`.
### `train`
```
qc-cli train start Submit a SageMaker training job
qc-cli train start --upload-metrics Submit, wait, and upload metrics
qc-cli train status [job-name] Show job status; defaults to the last submitted job
qc-cli train list List recent training jobs
qc-cli train list --limit 3 Show a custom number of recent jobs
```
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
The expected output artifact is SageMakers `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
### `ai-hub`
```
qc-cli ai-hub upload <calibration.npz|calibration-dir> <inputs.npz|inputs.npy>
qc-cli ai-hub upload <calibration> <inputs> --from-step validate
qc-cli ai-hub optimize [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
qc-cli ai-hub quantize <calibration.npz|calibration-dir> [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
qc-cli ai-hub compile [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
qc-cli ai-hub validate <inputs.npz|inputs.npy> [--model-id ID] [--input-name NAME]
qc-cli ai-hub profile [--model-id ID]
qc-cli ai-hub download [--model-id ID] [--output PATH]
```
`ai-hub upload` optimizes to ONNX, quantizes, validates, and profiles. When `aihub.target_runtime` is not `onnx`, it also compiles the quantized model to that deployment runtime. The initial ONNX optimization gives external models Workbench provenance and applies compiler optimization passes before quantization.
Resume behavior:
```text
--from-step optimize Run optimize, quantize, optional final compile, validate, and profile.
--from-step quantize Quantize the last optimized ONNX, then optionally compile, validate, and profile.
--from-step compile Skip optimize and quantize; finalize the last quantized model for the target runtime.
--from-step validate Skip optimize, quantize, and compile; validate the last compiled model.
--from-step profile Skip optimize, quantize, compile, and validate; profile the last compiled model.
```
When a step runs in the current command, `upload` passes its returned model ID directly to the next step. When a step is skipped, the next step resolves the needed model ID from `.qc-cli.json`. This avoids re-running earlier AI Hub jobs when you only need to continue from a later step.
`ai-hub optimize` compiles an external model with `--target_runtime onnx`. `ai-hub quantize` uses an explicit `--model-id`, the last optimized ONNX model, or an explicit/local model source in that order. `ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options, last quantized model, then the last training job. For `target_runtime: onnx`, upload treats the quantized ONNX as the final model and skips a redundant second compile. `ai-hub download` remains separate because downloading is outside the Workbench processing loop.
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
## Model lifecycle
The CLI uses neutral experiment naming for trained artifacts and reserves release terminology for an explicit promotion step.
Current behavior:
1. `qc-cli train start` submits a SageMaker training job.
2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
- `qc_cli.stage=experiment`
- `qc_cli.artifact_kind=trained_source`
- `qc_cli.source=sagemaker`
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
```json
{
"schema_version": 1,
"steps": [
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
],
"summary": {"summary.best_epoch": 0}
}
```
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues model registration without per-epoch history. A malformed metrics artifact still fails the upload command without affecting the trained model or model registration.
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
Example future metadata:
```text
qc-cli-model version 12
qc_cli.stage=experiment
qc_cli.artifact_kind=trained_source
qc_cli.source=sagemaker
qc-cli-model-aihub version 3
qc_cli.stage=ai_hub_compiled
qc_cli.artifact_kind=deployable
qc_cli.parent_registered_model_name=qc-cli-model
qc_cli.parent_model_version=12
qc_cli.runtime=tflite
qc_cli.quantization=int8
qc_cli.target_device=Samsung Galaxy S25
```
In that flow, `experiment-latest` remains a training convenience alias. Release selection is a separate promotion decision based on the derived artifact, not on the experiment name.
## AWS permissions required
The IAM user or role running the CLI needs:
@@ -101,6 +272,7 @@ The IAM user or role running the CLI needs:
| CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM |
| CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation |
| GetCallerIdentity | STS |
| CreateTrainingJob, DescribeTrainingJob, ListTrainingJobs | SageMaker AI |
| CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` |
`AdministratorAccess` covers all of the above.

8
app.py
View File

@@ -3,22 +3,24 @@ import os
import aws_cdk as cdk
from src.commands.utils import load_config
from src.infra.stack import QaiStack
from src.infra.stack import QCStack
app = cdk.App()
config_path = app.node.try_get_context("config") or "config.yaml"
stack_name = app.node.try_get_context("stack_name") or "QaiCliStack"
account_id = app.node.try_get_context("account_id") or os.getenv("CDK_DEFAULT_ACCOUNT")
delete_bucket_data = str(app.node.try_get_context("delete_bucket_data") or "false").lower() == "true"
cfg = load_config(config_path)
stack_name = app.node.try_get_context("stack_name") or cfg.infra.stack_name
bootstrap_qualifier = app.node.try_get_context("bootstrap_qualifier") or cfg.infra.effective_bootstrap_qualifier
QaiStack(
QCStack(
app,
stack_name,
config=cfg,
delete_bucket_data=delete_bucket_data,
synthesizer=cdk.DefaultStackSynthesizer(qualifier=bootstrap_qualifier),
env=cdk.Environment(
account=account_id,
region=cfg.aws.region,

View File

@@ -0,0 +1,285 @@
# YOLO26 Electric Meter Detection Example
This example trains a YOLO26 object detection model on the Roboflow Universe electric meter dataset using the existing `qc-cli` SageMaker training flow.
The workflow is intentionally command driven. Run each step yourself so you can inspect the dataset, update `config.yaml`, and decide when to submit the SageMaker job.
Dataset:
```text
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
```
## Prerequisites
- Install or sync the project dependencies: `uv sync`
- The virtual environment is activated.
- AWS credentials configured for the profile in `config.yaml`
- Infrastructure already deployed with `qc-cli infra setup`
## 1. Download The Dataset
Register or sign in to Roboflow, then open the dataset page:
```text
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
```
Download the dataset in YOLOv26 format from the Roboflow UI, then extract the downloaded archive into:
```text
examples/meter-detection/data/electric-meter-detection
```
The `data.yaml` file should be directly under that folder:
```text
examples/meter-detection/data/electric-meter-detection/data.yaml
```
Do not move `data.yaml` into the `train/` split folder.
After extracting, confirm the dataset has a YOLO data file and image splits:
```bash
find examples/meter-detection/data/electric-meter-detection -maxdepth 2 -type d | sort
find examples/meter-detection/data/electric-meter-detection -name data.yaml -print
```
Open `examples/meter-detection/data/electric-meter-detection/data.yaml` and make sure the split paths are relative to that folder:
```yaml
path: .
train: train/images
val: valid/images
test: test/images
```
If your downloaded dataset does not include a `test/` folder, remove the `test:` line.
The expected layout is similar to:
```text
examples/meter-detection/data/electric-meter-detection/
data.yaml
train/
valid/
test/
```
## 2. Configure SageMaker Training
Update `config.yaml` so the training section points at this example's source directory:
```yaml
sagemaker:
training:
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
instance_type: ml.g4dn.xlarge
instance_count: 1
source_dir: examples/meter-detection/source
entry_point: train.py
hyperparameters:
model: yolo26n.pt
epochs: 25
imgsz: 640
batch: 16
workers: 2
```
Use `yolo26n.pt` for a lightweight first YOLO26 run. If those weights are unavailable in the installed Ultralytics package, use `yolo11n.pt` as the established fallback:
```yaml
model: yolo11n.pt
```
The `source/requirements.txt` file is installed by the SageMaker PyTorch container before running `train.py`.
For a CPU smoke test, use a CPU instance and reduce the workload:
```yaml
sagemaker:
training:
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
instance_type: ml.m4.xlarge
instance_count: 1
source_dir: examples/meter-detection/source
entry_point: train.py
hyperparameters:
model: yolo26n.pt
epochs: 1
imgsz: 320
batch: 4
workers: 2
```
## 3. Check Infrastructure
Confirm the CLI can see the configured SageMaker role and S3 bucket:
```bash
qc-cli infra status
```
## 4. Upload The Dataset
Upload the downloaded Roboflow dataset to the `s3.data_prefix` configured in `config.yaml`:
```bash
qc-cli upload examples/meter-detection/data/electric-meter-detection
```
Directory uploads preserve paths relative to the uploaded directory, so SageMaker receives the dataset root with `data.yaml` plus the split directories.
In SageMaker, this uploaded dataset root is mounted at `/opt/ml/input/data/train`. That `train` path is the SageMaker channel name, not the YOLO `train/` split folder.
## 5. Start Training
Submit the SageMaker training job:
```bash
qc-cli train start
```
The command prints the submitted SageMaker job name. Check progress with:
```bash
qc-cli train status
```
Or pass the job name explicitly:
```bash
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
```
To submit the job, wait for completion, and automatically import metrics and register the model, run:
```bash
qc-cli train start --upload-metrics
```
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
The metrics can be also submitted using:
```bash
qc-cli mlflow upload-metrics
```
## SageMaker Outputs
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
This example writes:
```text
best.pt
model.onnx
metrics.json
training_metrics.json
```
The archive is stored under the configured `s3.model_prefix`.
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
are more meaningful than classification accuracy when assessing model quality.
## 6. Configure Qualcomm AI Hub
Authenticate with Qualcomm AI Hub:
```bash
qai-hub configure --api_token
```
Add AI Hub settings to `config.yaml`. The input name and image size must match the ONNX model exported by this example:
```yaml
aihub:
device:
name: Dragonwing IQ-9075 EVK
target_runtime: onnx
input_specs:
images: [[1, 3, 640, 640], float32]
job_name: meter-detection
model_name: meter-detection
output_dir: build/qai-hub/meter-detection
```
The ONNX graph is the source of truth. The export normally uses the same value as `sagemaker.training.hyperparameters.imgsz`, but changing `config.yaml` after training does not resize an existing model. For example, a model exported with `imgsz: 320` requires `images: [[1, 3, 320, 320], float32]`.
## 7. Prepare AI Hub Inputs
Generate calibration samples and a validation input from the downloaded dataset:
```bash
uv run python examples/meter-detection/prepare_aihub_inputs.py --image-size 640
```
This writes:
```text
examples/meter-detection/data/aihub_calibration/*.npy
examples/meter-detection/data/inputs.npz
```
The script applies the preprocessing expected by the exported YOLO model: aspect-ratio-preserving letterboxing, RGB channel order, channel-first layout, and pixel values normalized to `[0, 1]`.
## 8. Upload To Qualcomm AI Hub
Use the SageMaker job name printed by `qc-cli train start`:
```bash
qc-cli ai-hub upload \
examples/meter-detection/data/aihub_calibration \
examples/meter-detection/data/inputs.npz \
--from-job qc-cli-YYYYMMDD-HHMMSS
```
The command downloads the job's `model.tar.gz`, finds `model.onnx`, and runs the following AI Hub workflow:
1. Compile the external ONNX to a Workbench-optimized ONNX model.
2. Quantize the optimized ONNX model.
3. Compile the quantized model when the configured deployment runtime is not `onnx`.
4. Validate and profile the final model.
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local model. Replace the job name in both paths:
```bash
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
--output build/qai-hub/meter-detection/model.aihub.onnx
qc-cli ai-hub upload \
examples/meter-detection/data/aihub_calibration \
examples/meter-detection/data/inputs.npz \
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
```
Download the compiled artifact after the workflow completes:
```bash
qc-cli ai-hub download --output build/qai-hub/meter-detection/model.tflite
```
## Training Hyperparameters
Values under `sagemaker.training.hyperparameters` are passed to `source/train.py` as command-line arguments.
| Name | Type | Default | Description |
|---|---:|---:|---|
| `model` | string | `yolo26n.pt` | Ultralytics model weights or model YAML. |
| `epochs` | int | `25` | Number of training epochs. |
| `imgsz` | int | `640` | Square training image size. |
| `batch` | int | `16` | Images per training batch. |
| `workers` | int | `2` | DataLoader worker count. |
| `patience` | int | `20` | Early stopping patience. |
| `device` | string | auto | Optional Ultralytics device value such as `0` or `cpu`. |
| `data-yaml` | string | auto | Optional path to `data.yaml`; normally discovered from the uploaded dataset root. |
| `dataset-dir` | string | `SM_CHANNEL_TRAIN` | Uploaded dataset root mounted by SageMaker. |
Do not set `dataset-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Prepare Qualcomm AI Hub calibration and validation inputs for the meter detector."""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dataset-dir",
type=Path,
default=Path("examples/meter-detection/data/electric-meter-detection"),
help="Root of the extracted Roboflow dataset.",
)
parser.add_argument(
"--calibration-dir",
type=Path,
default=Path("examples/meter-detection/data/aihub_calibration"),
help="Directory where .npy calibration samples will be written.",
)
parser.add_argument(
"--input-file",
type=Path,
default=Path("examples/meter-detection/data/inputs.npz"),
help="Validation .npz input file for qc-cli ai-hub validate.",
)
parser.add_argument("--input-name", default="images", help="ONNX input name.")
parser.add_argument("--image-size", type=int, default=640, help="Square image size used for ONNX export.")
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
return parser.parse_args()
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
"""Apply Ultralytics-style letterboxing and produce an NCHW float32 tensor."""
with Image.open(path) as source:
image = source.convert("RGB")
scale = min(image_size / image.width, image_size / image.height)
resized_width = round(image.width * scale)
resized_height = round(image.height * scale)
image = image.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
canvas = Image.new("RGB", (image_size, image_size), (114, 114, 114))
left = round((image_size - resized_width) / 2 - 0.1)
top = round((image_size - resized_height) / 2 - 0.1)
canvas.paste(image, (left, top))
array = np.asarray(canvas, dtype=np.float32) / 255.0
return np.transpose(array, (2, 0, 1))[None, ...].astype(np.float32)
def main() -> None:
args = parse_args()
if args.image_size < 1:
raise SystemExit("--image-size must be at least 1")
if args.samples < 1:
raise SystemExit("--samples must be at least 1")
images = sorted(
path
for path in args.dataset_dir.rglob("*")
if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS and path.parent.name == "images"
)
if not images:
raise SystemExit(f"No images found under {args.dataset_dir}")
args.calibration_dir.mkdir(parents=True, exist_ok=True)
args.input_file.parent.mkdir(parents=True, exist_ok=True)
for stale_sample in args.calibration_dir.glob("sample_*.npy"):
stale_sample.unlink()
prepared: list[np.ndarray] = []
for index, image_path in enumerate(images[: args.samples]):
sample = preprocess_image(image_path, args.image_size)
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
prepared.append(sample)
np.savez(args.input_file, **{args.input_name: prepared[0]}) # pyright: ignore[reportArgumentType]
print(f"Wrote {len(prepared)} calibration samples to {args.calibration_dir}")
print(f"Wrote validation input to {args.input_file}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,3 @@
ultralytics>=8.3.0
pyyaml>=6.0.3
onnx>=1.16.0

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
from __future__ import annotations
import argparse
from pathlib import Path
import onnx # type: ignore[reportMissingImports]
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
model = onnx.load(path)
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
destination = output_path or path
if len(retained_value_info) != len(model.graph.value_info):
del model.graph.value_info[:]
model.graph.value_info.extend(retained_value_info)
destination.parent.mkdir(parents=True, exist_ok=True)
onnx.save(model, destination)
return destination
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("onnx_path", type=Path)
parser.add_argument("--output", type=Path)
args = parser.parse_args()
written = sanitize_onnx(args.onnx_path, args.output)
print(f"Saved sanitized ONNX model to {written}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""SageMaker entry point for YOLO electric meter detection training."""
from __future__ import annotations
import argparse
import json
import os
import shutil
from pathlib import Path
from typing import Any
import yaml
from sanitize_onnx import sanitize_onnx
from training_metrics import write_training_metrics
from ultralytics import YOLO # type: ignore[reportMissingImports]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="yolo26n.pt")
parser.add_argument("--epochs", type=int, default=25)
parser.add_argument("--imgsz", type=int, default=640)
parser.add_argument("--batch", type=int, default=16)
parser.add_argument("--workers", type=int, default=2)
parser.add_argument("--patience", type=int, default=20)
parser.add_argument("--device", default=None)
parser.add_argument("--data-yaml", default=None)
parser.add_argument("--dataset-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
parser.add_argument("--train-dir", dest="dataset_dir", help=argparse.SUPPRESS)
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
return parser.parse_args()
def find_data_yaml(dataset_dir: Path, explicit_path: str | None) -> Path:
if explicit_path:
data_yaml = Path(explicit_path)
if data_yaml.is_file():
return data_yaml
raise FileNotFoundError(f"Configured data.yaml does not exist: {data_yaml}")
matches = sorted(dataset_dir.rglob("data.yaml"))
if not matches:
raise FileNotFoundError(f"Could not find data.yaml under {dataset_dir}")
if len(matches) > 1:
print(f"Found multiple data.yaml files; using {matches[0]}")
return matches[0]
def prepare_data_yaml(data_yaml: Path) -> Path:
"""Write a SageMaker-local data file rooted at the uploaded dataset."""
dataset_root = data_yaml.parent
data = yaml.safe_load(data_yaml.read_text(encoding="utf-8"))
if not isinstance(data, dict):
raise ValueError(f"Expected a mapping in {data_yaml}")
normalized = dict(data)
normalized["path"] = str(dataset_root)
if "val" not in normalized and "valid" in normalized:
normalized["val"] = normalized.pop("valid")
prepared_path = dataset_root / "data.sagemaker.yaml"
prepared_path.write_text(yaml.safe_dump(normalized, sort_keys=False), encoding="utf-8")
print(f"Prepared dataset config: {prepared_path}")
return prepared_path
def copy_if_exists(source: Path, destination: Path) -> None:
if source.exists():
shutil.copy2(source, destination)
print(f"Saved {destination}")
def main() -> None:
args = parse_args()
dataset_dir = Path(args.dataset_dir)
model_dir = Path(args.model_dir)
model_dir.mkdir(parents=True, exist_ok=True)
data_yaml = prepare_data_yaml(find_data_yaml(dataset_dir, args.data_yaml))
model = YOLO(args.model)
train_kwargs: dict[str, Any] = {
"data": str(data_yaml),
"epochs": args.epochs,
"imgsz": args.imgsz,
"batch": args.batch,
"workers": args.workers,
"patience": args.patience,
"project": str(model_dir / "runs"),
"name": "train",
"exist_ok": True,
}
if args.device:
train_kwargs["device"] = args.device
results = model.train(**train_kwargs)
save_dir = Path(results.save_dir)
best_pt = save_dir / "weights" / "best.pt"
last_pt = save_dir / "weights" / "last.pt"
trained_weights = best_pt if best_pt.exists() else last_pt
if not trained_weights.exists():
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
print(f"Saved {saved_onnx_path}")
metrics = {
"model": args.model,
"epochs": args.epochs,
"imgsz": args.imgsz,
"batch": args.batch,
"workers": args.workers,
"patience": args.patience,
"data_yaml": str(data_yaml),
"weights": str(trained_weights),
"onnx": str(saved_onnx_path),
}
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print(f"Saved model artifacts to {model_dir}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,82 @@
import csv
import json
import math
import re
from pathlib import Path
from typing import Any
METRIC_NAMES = {
"metrics/precision(B)": "val.precision",
"metrics/recall(B)": "val.recall",
"metrics/mAP50(B)": "val.map50",
"metrics/mAP50-95(B)": "val.map50_95",
"train/box_loss": "train.box_loss",
"train/cls_loss": "train.cls_loss",
"train/dfl_loss": "train.dfl_loss",
"val/box_loss": "val.box_loss",
"val/cls_loss": "val.cls_loss",
"val/dfl_loss": "val.dfl_loss",
"time": "train.elapsed_seconds",
}
def write_training_metrics(results_csv: Path, destination: Path) -> None:
steps = _read_metric_steps(results_csv)
summary = _build_summary(steps)
payload = {
"schema_version": 1,
"steps": steps,
"summary": summary,
}
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
print(f"Saved {destination}")
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
if not results_csv.is_file():
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
steps: list[dict[str, Any]] = []
with results_csv.open(newline="", encoding="utf-8") as csv_file:
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
row = {str(key).strip(): value for key, value in raw_row.items()}
raw_epoch = row.pop("epoch", row_index)
step = int(float(raw_epoch))
metrics: dict[str, float] = {}
for source_name, raw_value in row.items():
if raw_value is None or not raw_value.strip():
continue
try:
value = float(raw_value)
except ValueError:
continue
if math.isfinite(value):
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
steps.append({"step": step, "metrics": metrics})
return steps
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
if not steps:
return {}
summary: dict[str, float] = {}
final_step = steps[-1]
summary["summary.final_epoch"] = float(final_step["step"])
for name, value in final_step["metrics"].items():
summary[f"summary.final.{name}"] = value
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
if scored_steps:
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
summary["summary.best_epoch"] = float(best_step["step"])
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
if "val.map50" in best_step["metrics"]:
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
return summary
def _normalize_metric_name(name: str) -> str:
normalized = name.replace("/", ".")
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
return normalized.strip("._") or "unnamed"

View File

@@ -3,21 +3,25 @@ requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "qai-cli"
name = "qc-cli"
version = "0.1.0"
description = "CLI for SageMaker ONNX training and Qualcomm AI Hub optimization"
description = "CLI for training and deploying models for Qualcomm AI Hub"
requires-python = ">=3.13"
dependencies = [
"aws-cdk-lib>=2.180.0",
"typer==0.25.0",
"boto3>=1.34,<1.42",
"constructs>=10.0.0",
"mlflow>=3.0",
"numpy>=1.26",
"pydantic>=2.13.3",
"pyyaml>=6.0.3",
"qai-hub>=0.49.0",
"sagemaker-mlflow>=0.4.0",
]
[project.scripts]
qai-cli = "src.main:app"
qc-cli = "src.main:app"
[tool.hatch.build.targets.wheel]
packages = ["src"]

View File

@@ -3,13 +3,11 @@ from typing import Any
import boto3
from botocore.exceptions import ClientError
from src.infra.provisioning import STACK_NAME
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
def stack_status(region: str, profile: str, stack_name: str) -> dict[str, Any] | None:
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
try:
stack = client.describe_stacks(StackName=STACK_NAME)["Stacks"][0]
stack = client.describe_stacks(StackName=stack_name)["Stacks"][0]
except ClientError as e:
message = e.response.get("Error", {}).get("Message", "")
if "does not exist" in message:

17
src/aws/iam.py Normal file
View File

@@ -0,0 +1,17 @@
import boto3
from botocore.exceptions import ClientError
from mypy_boto3_iam import IAMClient
def _client(profile: str) -> IAMClient:
return boto3.Session(profile_name=profile).client("iam")
def get_role_arn(profile: str, role_name: str) -> str | None:
client = _client(profile)
try:
return client.get_role(RoleName=role_name)["Role"]["Arn"]
except ClientError as e:
if e.response.get("Error", {}).get("Code") == "NoSuchEntity":
return None
raise

View File

@@ -1,3 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast
import boto3
@@ -17,3 +20,55 @@ def describe_tracking_server(region: str, profile: str, name: str) -> dict[str,
):
return None
raise
def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
server = describe_tracking_server(region, profile, name)
if not server:
raise ValueError(f"MLflow tracking server not found: {name}")
arn = server.get("TrackingServerArn")
if not arn:
raise ValueError(f"MLflow tracking server has no ARN: {name}")
return str(arn)
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
if credentials is None:
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
frozen_credentials = credentials.get_frozen_credentials()
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
env_updates = {
"AWS_PROFILE": profile,
"AWS_DEFAULT_REGION": region,
"AWS_REGION": region,
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
}
if frozen_credentials.token:
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
previous_env = {key: os.environ.get(key) for key in restore_keys}
try:
os.environ.update(env_updates)
if not frozen_credentials.token:
os.environ.pop("AWS_SESSION_TOKEN", None)
yield
finally:
for key, value in previous_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

69
src/aws/s3.py Normal file
View File

@@ -0,0 +1,69 @@
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import boto3
from mypy_boto3_s3 import S3Client
def _client(region: str, profile: str) -> S3Client:
return boto3.Session(profile_name=profile, region_name=region).client("s3")
def upload_file(
region: str,
profile: str,
bucket: str,
local_path: str,
s3_key: str,
) -> str:
_client(region, profile).upload_file(local_path, bucket, s3_key)
return f"s3://{bucket}/{s3_key}"
def download_file(
region: str,
profile: str,
s3_uri: str,
local_path: str,
) -> str:
if not s3_uri.startswith("s3://"):
raise ValueError(f"Expected S3 URI, got: {s3_uri}")
bucket_key = s3_uri.removeprefix("s3://")
bucket, _, key = bucket_key.partition("/")
if not bucket or not key:
raise ValueError(f"Expected S3 URI with bucket and key, got: {s3_uri}")
dest = Path(local_path)
dest.parent.mkdir(parents=True, exist_ok=True)
_client(region, profile).download_file(bucket, key, str(dest))
return str(dest)
def upload_dir(
region: str,
profile: str,
bucket: str,
local_dir: str,
s3_prefix: str,
on_progress: Callable[[], None] | None = None,
) -> int:
root = Path(local_dir)
files = [file for file in root.rglob("*") if file.is_file()]
if not files:
return 0
client = _client(region, profile)
prefix = s3_prefix.rstrip("/")
def upload_one(file_path: Path) -> None:
key = f"{prefix}/{file_path.relative_to(root)}"
client.upload_file(str(file_path), bucket, key)
if on_progress:
on_progress()
with ThreadPoolExecutor(max_workers=10) as pool:
futures = [pool.submit(upload_one, file) for file in files]
for future in as_completed(futures):
future.result()
return len(files)

143
src/aws/sagemaker.py Normal file
View File

@@ -0,0 +1,143 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
import boto3
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from mypy_boto3_sagemaker.type_defs import (
CreateTrainingJobRequestTypeDef,
ResourceConfigTypeDef,
TrainingJobSummaryTypeDef,
)
from src.config import Boto3SessionKwargs
@dataclass(frozen=True)
class TrainingJobRequest:
role_arn: str
image_uri: str
instance_type: TrainingInstanceTypeType
instance_count: int
s3_train_uri: str
s3_output_path: str
job_name: str
hyperparameters: dict[str, Any] = field(default_factory=dict)
entry_point: str | None = None
source_dir: str | None = None
@dataclass(frozen=True)
class TrainingJobStatus:
name: str
status: str
created: datetime | None
modified: datetime | None
model_artifacts: str | None
failure_reason: str | None
raw: dict[str, Any] = field(default_factory=dict)
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
return boto3.Session(**session).client("sagemaker")
def _upload_source_dir(
session: Boto3SessionKwargs,
source_dir: str,
s3_output_path: str,
job_name: str,
) -> str:
import io
import tarfile
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
tar.add(source_dir, arcname=".")
buf.seek(0)
without_scheme = s3_output_path.removeprefix("s3://")
bucket, _, prefix = without_scheme.partition("/")
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
return f"s3://{bucket}/{key}"
def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
hp = {k: str(v) for k, v in job.hyperparameters.items()}
if job.source_dir:
s3_code_uri = _upload_source_dir(
session,
job.source_dir,
job.s3_output_path,
job.job_name,
)
hp["sagemaker_program"] = job.entry_point or "train.py"
hp["sagemaker_submit_directory"] = s3_code_uri
resource_config: ResourceConfigTypeDef = {
"InstanceType": job.instance_type,
"InstanceCount": job.instance_count,
"VolumeSizeInGB": 30,
}
request: CreateTrainingJobRequestTypeDef = {
"TrainingJobName": job.job_name,
"AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"},
"RoleArn": job.role_arn,
"InputDataConfig": [
{
"ChannelName": "train",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": job.s3_train_uri,
"S3DataDistributionType": "FullyReplicated",
}
},
}
],
"OutputDataConfig": {"S3OutputPath": job.s3_output_path},
"ResourceConfig": resource_config,
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"HyperParameters": hp,
}
_sm(session).create_training_job(**request)
return job.job_name
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
resp = _sm(session).describe_training_job(TrainingJobName=job_name)
return TrainingJobStatus(
name=resp["TrainingJobName"],
status=resp["TrainingJobStatus"],
created=resp.get("CreationTime"),
modified=resp.get("LastModifiedTime"),
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
failure_reason=resp.get("FailureReason"),
raw=dict(resp),
)
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
resp = boto3.Session(profile_name=profile, region_name=region).client("sagemaker").describe_training_job(
TrainingJobName=job_name
)
artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
if not artifact:
raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
return str(artifact)
def list_training_jobs(
session: Boto3SessionKwargs,
max_results: int = 10,
) -> list[TrainingJobSummaryTypeDef]:
resp = _sm(session).list_training_jobs(
SortBy="CreationTime",
SortOrder="Descending",
MaxResults=max_results,
)
return list(resp["TrainingJobSummaries"])

1
src/cloud/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Cloud provider adapters."""

77
src/cloud/mlflow.py Normal file
View File

@@ -0,0 +1,77 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
@property
def provider_name(self) -> str: ...
@property
def profile(self) -> str: ...
@property
def region(self) -> str: ...
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
def auth_env(self) -> AbstractContextManager[None]: ...
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
profile: str
region: str
provider_name: str = "aws"
def get_tracking_uri(self, tracking_server_name: str) -> str:
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
def auth_env(self) -> AbstractContextManager[None]:
return aws_mlflow.tracking_auth_env(self.profile, self.region)
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
return {
"provider.name": self.provider_name,
"provider.region": region,
"provider.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job.job_name}
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
return {
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)

567
src/commands/ai_hub.py Normal file
View File

@@ -0,0 +1,567 @@
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from enum import StrEnum
from pathlib import Path
from typing import Any
import qai_hub.hub as hub
import typer
from qai_hub.client import Device
from src import state as state_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config
from src.qualcomm import aihub_jobs
from src.qualcomm.artifacts import ResolvedOnnx, resolve_onnx
app = typer.Typer(help="Optimize, quantize, compile, validate, profile, and download models with Qualcomm Workbench")
_RUNTIME_EXTENSIONS = {
"tflite": "tflite",
"qnn_context_binary": "bin",
"onnx": "onnx",
}
class UploadStep(StrEnum):
optimize = "optimize"
quantize = "quantize"
compile = "compile"
validate = "validate"
profile = "profile"
@dataclass(frozen=True)
class ResolvedModelSource:
model: str | Path
model_artifact: str | None = None
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
if not specs:
CONSOLE.print("[red]aihub.input_specs must define at least one input.[/red]")
raise typer.Exit(1)
return specs
def _load_inputs(
input_file: Path,
specs: Mapping[str, tuple[Sequence[int], str]],
input_name: str | None = None,
) -> dict[str, Any]:
import numpy as np
if not input_file.exists():
raise FileNotFoundError(f"File not found: {input_file}")
if input_file.suffix == ".npz":
loaded = np.load(input_file)
missing = set(specs) - set(loaded.files)
if missing:
raise ValueError(f"Missing input(s) in NPZ: {', '.join(sorted(missing))}")
return {name: loaded[name] for name in specs}
if input_file.suffix == ".npy":
if input_name is None:
if len(specs) != 1:
raise ValueError("--input-name is required when config has multiple inputs")
input_name = next(iter(specs))
if input_name not in specs:
raise ValueError(f"Input name '{input_name}' is not defined in aihub.input_specs")
return {input_name: np.load(input_file)}
raise ValueError("Input file must be .npz or .npy")
def _load_calibration(path: Path, specs: Mapping[str, tuple[Sequence[int], str]]) -> dict[str, Any]:
import numpy as np
if path.is_file():
return _load_inputs(path, specs)
if not path.is_dir():
raise FileNotFoundError(f"Calibration path not found: {path}")
if len(specs) != 1:
raise ValueError("Directory calibration data is supported only for single-input models.")
input_name = next(iter(specs))
samples = [np.load(p) for p in sorted(path.glob("*.npy"))]
if not samples:
raise ValueError(f"No .npy calibration samples found in {path}")
return {input_name: samples}
def _job_name(cfg: Config, operation: str) -> str | None:
if not cfg.aihub.job_name:
return None
return f"{cfg.aihub.job_name}-{operation}"
def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: bool = False) -> str:
st = state_ops.store(config_path)
resolved = model_id or (st.get_last_quantized_model_id() if quantized else st.get_last_compiled_model_id())
if not resolved:
source = "quantized" if quantized else "compiled"
CONSOLE.print(f"[red]No {source} model found. Pass --model-id or run the previous AI Hub step first.[/red]")
raise typer.Exit(1)
return resolved
def _resolve_model_source(
cfg: Config,
config_path: str,
*,
model_id: str | None = None,
previous_model_id: str | None = None,
from_job: str | None = None,
model_s3_uri: str | None = None,
onnx_path: str | None = None,
) -> ResolvedModelSource:
if model_id:
return ResolvedModelSource(model_id)
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
if previous_model_id and not has_explicit_source:
return ResolvedModelSource(previous_model_id)
resolved = _resolve_onnx_source(
cfg,
config_path,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
return ResolvedModelSource(resolved.onnx_path, resolved.model_artifact)
def _resolve_onnx_source(
cfg: Config,
config_path: str,
*,
from_job: str | None = None,
model_s3_uri: str | None = None,
onnx_path: str | None = None,
) -> ResolvedOnnx:
st = state_ops.store(config_path)
last_training_job = st.get_last_training_job()
saved_model_artifact = None
if not from_job and not model_s3_uri and not onnx_path and not last_training_job:
saved_model_artifact = st.get_last_model_artifact()
return resolve_onnx(
cfg=cfg,
output_dir=cfg.aihub.output_dir,
from_job=from_job,
model_s3_uri=model_s3_uri or saved_model_artifact,
onnx_path=onnx_path,
last_training_job=last_training_job,
)
def _device_selector(device: Device) -> str:
parts: list[str] = []
if device.name:
parts.append(f"name={device.name!r}")
if device.os:
parts.append(f"os={device.os!r}")
if device.attributes:
parts.append(f"attributes={device.attributes!r}")
return ", ".join(parts) if parts else "empty selector"
def _validate_device(cfg: Config) -> None:
device = cfg.aihub.device
try:
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
except Exception as e:
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
raise typer.Exit(1)
if matches:
return
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
raise typer.Exit(1)
def _quantize_step(
cfg: Config,
config_path: str,
calibration_path: Path,
*,
model_id: str | None = None,
from_job: str | None = None,
model_s3_uri: str | None = None,
onnx_path: str | None = None,
) -> str:
st = state_ops.store(config_path)
specs = _input_specs(cfg)
try:
source = _resolve_model_source(
cfg,
config_path,
model_id=model_id,
previous_model_id=st.get_last_optimized_model_id(),
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
calibration_data = _load_calibration(calibration_path, specs)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
try:
hub_model = (
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
if isinstance(source.model, Path)
else hub.get_model(source.model)
)
result = aihub_jobs.submit_quantize_job(
hub_model,
calibration_data,
cfg.aihub.quantize_options,
job_name=_job_name(cfg, "quantize"),
)
except Exception as e:
CONSOLE.print(f"[red]AI Hub quantize failed: {e}[/red]")
raise typer.Exit(1)
updates: dict[str, Any] = {
"last_quantize_job_id": result["job_id"],
"last_quantized_model_id": result["model_id"],
}
if source.model_artifact:
updates["last_model_artifact"] = source.model_artifact
st.update(**updates)
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
return str(result["model_id"])
def _optimize_step(
cfg: Config,
config_path: str,
from_job: str | None,
model_s3_uri: str | None,
onnx_path: str | None,
) -> str:
st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg)
try:
source = _resolve_onnx_source(
cfg,
config_path,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
try:
hub_model = hub.upload_model(str(source.onnx_path), name=cfg.aihub.model_name)
result = aihub_jobs.submit_compile_job(
model=hub_model,
device=cfg.aihub.device,
input_specs=specs,
target_runtime="onnx",
job_name=_job_name(cfg, "optimize"),
)
except Exception as e:
CONSOLE.print(f"[red]AI Hub ONNX optimization failed: {e}[/red]")
raise typer.Exit(1)
st.update(
last_model_artifact=source.model_artifact,
last_optimize_job_id=result["job_id"],
last_optimized_model_id=result["model_id"],
)
CONSOLE.print(f"[green]✓[/green] ONNX optimization job: [bold]{result['job_id']}[/bold]")
CONSOLE.print(f"[green]✓[/green] Optimized ONNX model: [bold]{result['model_id']}[/bold]")
return str(result["model_id"])
def _compile_step(
cfg: Config,
config_path: str,
*,
model_id: str | None = None,
from_job: str | None = None,
model_s3_uri: str | None = None,
onnx_path: str | None = None,
) -> str:
st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg)
try:
source = _resolve_model_source(
cfg,
config_path,
model_id=model_id,
previous_model_id=st.get_last_quantized_model_id(),
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
try:
hub_model = (
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
if isinstance(source.model, Path)
else hub.get_model(source.model)
)
result = aihub_jobs.submit_compile_job(
model=hub_model,
device=cfg.aihub.device,
input_specs=specs,
target_runtime=cfg.aihub.target_runtime,
options=cfg.aihub.compile_options,
job_name=_job_name(cfg, "compile"),
)
except Exception as e:
CONSOLE.print(f"[red]AI Hub compile failed: {e}[/red]")
raise typer.Exit(1)
updates: dict[str, Any] = {
"last_compile_job_id": result["job_id"],
"last_compiled_model_id": result["model_id"],
}
if source.model_artifact:
updates["last_model_artifact"] = source.model_artifact
st.update(**updates)
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
return str(result["model_id"])
def _validate_step(
cfg: Config,
config_path: str,
input_file: Path,
model_id: str | None,
input_name: str | None,
) -> str:
_validate_device(cfg)
specs = _input_specs(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
inputs = _load_inputs(input_file, specs, input_name)
except (FileNotFoundError, ValueError) as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
run = datetime.now().strftime("%Y%m%d-%H%M%S")
out_dir = Path(cfg.aihub.output_dir) / run / "validation"
try:
hub_model = hub.get_model(resolved_model_id)
result = aihub_jobs.submit_inference_job(
hub_model,
cfg.aihub.device,
inputs,
out_dir,
job_name=_job_name(cfg, "validate"),
)
except Exception as e:
CONSOLE.print(f"[red]AI Hub inference failed: {e}[/red]")
raise typer.Exit(1)
state_ops.store(config_path).update(last_inference_job_id=result["job_id"])
CONSOLE.print(f"[green]✓[/green] Inference job: [bold]{result['job_id']}[/bold]")
outputs = result.get("outputs")
if isinstance(outputs, dict):
for name, value in outputs.items():
CONSOLE.print(f" {name}: shape={getattr(value, 'shape', '?')}")
CONSOLE.print(f"Outputs: [cyan]{out_dir}[/cyan]")
return str(result["job_id"])
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
_validate_device(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id)
try:
hub_model = hub.get_model(resolved_model_id)
result = aihub_jobs.submit_profile_job(
hub_model,
cfg.aihub.device,
cfg.aihub.profile_options,
job_name=_job_name(cfg, "profile"),
)
except Exception as e:
CONSOLE.print(f"[red]AI Hub profile failed: {e}[/red]")
raise typer.Exit(1)
state_ops.store(config_path).update(last_profile_job_id=result["job_id"])
CONSOLE.print(f"[green]✓[/green] Profile job: [bold]{result['job_id']}[/bold]")
return str(result["job_id"])
@app.command()
def optimize(
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should optimize"),
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to optimize"),
onnx_path: str | None = typer.Option(
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
),
config: str = CONFIG_OPT,
) -> None:
"""Optimize an external model into a Workbench-produced ONNX model."""
cfg = load_cfg(config)
_optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
@app.command()
def quantize(
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub optimized ONNX model ID"),
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should quantize"),
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to quantize"),
onnx_path: str | None = typer.Option(
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
),
config: str = CONFIG_OPT,
) -> None:
"""Quantize an ONNX model to INT8."""
cfg = load_cfg(config)
_quantize_step(
cfg,
config,
calibration_path,
model_id=model_id,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
@app.command()
def compile(
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub model ID to compile"),
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should compile"),
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to compile"),
onnx_path: str | None = typer.Option(
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
),
config: str = CONFIG_OPT,
) -> None:
"""Compile a model for the configured Qualcomm AI Hub target."""
cfg = load_cfg(config)
_compile_step(
cfg,
config,
model_id=model_id,
from_job=from_job,
model_s3_uri=model_s3_uri,
onnx_path=onnx_path,
)
@app.command()
def validate(
input_file: Path = typer.Argument(..., help="NumPy .npz or .npy inputs to run on device"),
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy files"),
config: str = CONFIG_OPT,
) -> None:
"""Run an AI Hub inference job using sample inputs."""
cfg = load_cfg(config)
_validate_step(cfg, config, input_file, model_id, input_name)
@app.command()
def profile(
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
config: str = CONFIG_OPT,
) -> None:
"""Profile a compiled model on the configured AI Hub device."""
cfg = load_cfg(config)
_profile_step(cfg, config, model_id)
@app.command()
def upload(
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
input_file: Path = typer.Argument(..., help="Validation .npz or .npy inputs to run on device"),
from_step: UploadStep = typer.Option(UploadStep.optimize, "--from-step", help="Resume from this Workbench step"),
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should upload"),
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to upload"),
onnx_path: str | None = typer.Option(
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
),
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy validation files"),
config: str = CONFIG_OPT,
) -> None:
"""Optimize, quantize, optionally compile, validate, and profile a model."""
cfg = load_cfg(config)
steps = [UploadStep.optimize, UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
selected = steps[steps.index(from_step) :]
optimized_model_id: str | None = None
quantized_model_id: str | None = None
compiled_model_id: str | None = None
if UploadStep.optimize in selected:
optimized_model_id = _optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
if UploadStep.quantize in selected:
if UploadStep.optimize not in selected:
optimized_model_id = state_ops.store(config).get_last_optimized_model_id()
if not optimized_model_id:
CONSOLE.print(
"[red]No optimized ONNX model found. Resume from --from-step optimize or run "
"'qc-cli ai-hub optimize' first.[/red]"
)
raise typer.Exit(1)
quantized_model_id = _quantize_step(
cfg,
config,
calibration_path,
model_id=optimized_model_id,
)
if UploadStep.compile in selected:
if cfg.aihub.target_runtime == "onnx":
compiled_model_id = quantized_model_id or state_ops.store(config).get_last_quantized_model_id()
if not compiled_model_id:
CONSOLE.print(
"[red]No quantized ONNX model found. Resume from --from-step quantize or run "
"'qc-cli ai-hub quantize' first.[/red]"
)
raise typer.Exit(1)
state_ops.store(config).update(last_compiled_model_id=compiled_model_id)
CONSOLE.print("[green]✓[/green] Target runtime is ONNX; skipping final compile.")
else:
compiled_model_id = _compile_step(
cfg,
config,
model_id=quantized_model_id,
)
if UploadStep.validate in selected:
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
if UploadStep.profile in selected:
_profile_step(cfg, config, compiled_model_id)
@app.command()
def download(
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
output: Path | None = typer.Option(None, "--output", "-o", help="Destination file path"),
config: str = CONFIG_OPT,
) -> None:
"""Download the last compiled deployable artifact from AI Hub."""
cfg = load_cfg(config)
resolved_model_id = _model_id_or_state(config, model_id)
ext = _RUNTIME_EXTENSIONS.get(cfg.aihub.target_runtime, cfg.aihub.target_runtime)
dest = output or (Path(cfg.aihub.output_dir) / f"model.{ext}")
try:
written = aihub_jobs.download_model(resolved_model_id, dest)
except Exception as e:
CONSOLE.print(f"[red]AI Hub download failed: {e}[/red]")
raise typer.Exit(1)
state_ops.store(config).update(last_downloaded_model=written)
CONSOLE.print(f"[green]✓[/green] Downloaded model: [cyan]{written}[/cyan]")

View File

@@ -51,6 +51,8 @@ def setup(
profile=cfg.aws.profile,
account_id=account_id,
region=cfg.aws.region,
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
cloudformation_execution_policy=cloudformation_execution_policy,
)
with CONSOLE.status("Running cdk deploy..."):
@@ -58,6 +60,9 @@ def setup(
profile=cfg.aws.profile,
account_id=account_id,
region=cfg.aws.region,
stack_name=cfg.infra.stack_name,
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
config_path=config,
config_dir=str(Path(config).parent),
config_snapshot=cfg.model_dump(mode="json"),
@@ -72,7 +77,8 @@ def setup(
if outputs.get("SageMakerRoleArn"):
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
elif cfg.mlflow.mode is MlflowMode.existing:
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
@@ -82,7 +88,7 @@ def setup(
def status(config: str = CONFIG_OPT) -> None:
"""Show current infrastructure status."""
cfg = load_cfg(config)
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)
table = Table(title="Infrastructure Status")
table.add_column("Resource", style="cyan")
@@ -91,13 +97,13 @@ def status(config: str = CONFIG_OPT) -> None:
table.add_column("ARN / URI")
if not stack:
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
if cfg.mlflow.mode is not MlflowMode.disabled:
table.add_row(
"MLflow",
cfg.mlflow.tracking_server_name or "-",
cfg.effective_mlflow_tracking_server_name or "-",
"[red]unknown[/red]",
"-",
)
@@ -114,14 +120,14 @@ def status(config: str = CONFIG_OPT) -> None:
)
table.add_row(
"IAM Role",
cfg.sagemaker.role_name,
_role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
"[green]managed[/green]",
outputs.get("SageMakerRoleArn", "-"),
)
if cfg.mlflow.mode is MlflowMode.create:
table.add_row(
"MLflow",
cfg.mlflow.tracking_server_name or "-",
outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
"[green]managed[/green]",
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
)
@@ -156,10 +162,13 @@ def destroy(
) -> None:
"""Destroy the CDK stack."""
cfg = _destroy_config(config)
stack_name = _destroy_stack_name(config, cfg)
bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)
if not yes and not delete_bucket_data:
typer.confirm(
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?",
abort=True,
)
@@ -172,13 +181,17 @@ def destroy(
provisioning.destroy(
profile=cfg.aws.profile,
account_id=account_id,
stack_name=stack_name,
bootstrap_qualifier=bootstrap_qualifier,
toolkit_stack_name=toolkit_stack_name,
config_path=str(snapshot_path),
delete_bucket_data=delete_bucket_data,
)
except RuntimeError as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
def _destroy_config(config_path: str) -> Config:
@@ -190,6 +203,14 @@ def _destroy_config(config_path: str) -> Config:
return load_cfg(config_path)
def _role_name(configured_name: str, role_arn: str) -> str:
if configured_name:
return configured_name
if role_arn:
return role_arn.rsplit("/", 1)[-1]
return "-"
def _destroy_account_id(config_path: str, cfg: Config) -> str:
config_dir = str(Path(config_path).parent)
state = read_infra_state(config_dir)
@@ -197,3 +218,30 @@ def _destroy_account_id(config_path: str, cfg: Config) -> str:
if account_id:
return str(account_id)
return identity.account_id(cfg.aws.region, cfg.aws.profile)
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
config_dir = str(Path(config_path).parent)
state = read_infra_state(config_dir)
stack_name = state.get("stack_name")
if stack_name:
return str(stack_name)
return cfg.infra.stack_name
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
config_dir = str(Path(config_path).parent)
state = read_infra_state(config_dir)
bootstrap_qualifier = state.get("bootstrap_qualifier")
if bootstrap_qualifier:
return str(bootstrap_qualifier)
return cfg.infra.effective_bootstrap_qualifier
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
config_dir = str(Path(config_path).parent)
state = read_infra_state(config_dir)
toolkit_stack_name = state.get("toolkit_stack_name")
if toolkit_stack_name:
return str(toolkit_stack_name)
return cfg.infra.effective_toolkit_stack_name

40
src/commands/init.py Normal file
View File

@@ -0,0 +1,40 @@
import secrets
from pathlib import Path
import typer
import yaml
from src.commands.utils import CONSOLE
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
app = typer.Typer()
@app.command()
def init(
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
) -> None:
"""Write a starter config.yaml to the current directory."""
dest = Path(output)
if dest.exists() and not force:
CONSOLE.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
raise typer.Exit(1)
config = _new_isolated_config()
dest.parent.mkdir(parents=True, exist_ok=True)
config_data = config.model_dump(mode="json")
config_data["sagemaker"].pop("role_name", None)
with open(dest, "w") as f:
yaml.safe_dump(config_data, f, sort_keys=False)
CONSOLE.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
CONSOLE.print("Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands.")
def _new_isolated_config() -> Config:
suffix = secrets.token_hex(6)
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
config = Config(infra=InfraConfig(stack_name=namespace))
config.s3 = S3Config(bucket=f"{namespace}-data")
return config

95
src/commands/mlflow.py Normal file
View File

@@ -0,0 +1,95 @@
import webbrowser
import typer
from src import state as state_ops
from src.aws import mlflow as aws_mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import MlflowMode
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage MLflow tracking server access")
@app.command(name="open")
def open_mlflow(config: str = CONFIG_OPT) -> None:
"""Open a presigned URL for the configured MLflow tracking server."""
cfg = load_cfg(config)
tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
try:
url = aws_mlflow.create_presigned_tracking_server_url(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
except Exception as e:
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"Reason: {e}")
CONSOLE.print(
"This command can create presigned URLs only for MLflow tracking servers managed by "
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
)
raise typer.Exit(1)
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"MLflow UI: {url}")
if webbrowser.open(url):
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
else:
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
@app.command(name="upload-metrics")
def upload_metrics(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a completed training job's metric history to MLflow."""
cfg = load_cfg(config)
if cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
st = state_ops.store(config)
if not job_name:
job_name = st.get_last_training_job()
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
return
try:
result = upload_training_metrics(
job_name=job_name,
config_path=config,
cfg=cfg,
force=force,
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
raise typer.Exit(1)
if result.metrics_history_uploaded:
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
else:
CONSOLE.print(
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
)
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)

265
src/commands/train.py Normal file
View File

@@ -0,0 +1,265 @@
import time
from datetime import datetime
from pathlib import Path
import typer
from rich.table import Table
from src import state as state_ops
from src.aws import iam
from src.aws import sagemaker as sm_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config, MlflowMode
from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker
from src.tracking.upload import upload_training_metrics
app = typer.Typer(help="Manage SageMaker training jobs")
_STATUS_COLOR = {
"Completed": "green",
"Failed": "red",
"InProgress": "yellow",
"Stopping": "yellow",
"Stopped": "dim",
}
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
DEFAULT_POLL_INTERVAL_SECONDS = 30
def _tracker(cfg):
try:
return MlflowTracker.from_config(cfg)
except Exception as e:
CONSOLE.print(f"[red]MLflow setup failed: {e}[/red]")
raise typer.Exit(1)
def _config_dir(config_path: str) -> str:
return str(Path(config_path).parent)
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
state = read_infra_state(_config_dir(config_path))
role_arn = state.get("outputs", {}).get("SageMakerRoleArn")
if role_arn:
return str(role_arn)
if cfg.sagemaker.role_name:
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
if role_arn:
return role_arn
raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
color = _STATUS_COLOR.get(status.status, "white")
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
if status.created:
CONSOLE.print(f"Created: {status.created}")
if status.model_artifacts:
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
if status.failure_reason:
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
def _wait_and_upload_metrics(
*,
job_name: str,
poll_interval: int,
config_path: str,
cfg: Config,
) -> None:
st = state_ops.store(config_path)
previous_status: str | None = None
try:
while True:
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if training_status.status != previous_status:
color = _STATUS_COLOR.get(training_status.status, "white")
CONSOLE.print(
f"Job [cyan]{training_status.name}[/cyan]: "
f"[{color}]{training_status.status}[/{color}]"
)
previous_status = training_status.status
if training_status.status in _TERMINAL_STATUSES:
_print_training_status(training_status)
if training_status.status != "Completed":
raise typer.Exit(1)
try:
result = upload_training_metrics(
job_name=job_name,
config_path=config_path,
cfg=cfg,
)
if result.metrics_history_uploaded:
CONSOLE.print(
f"[green]✓[/green] Uploaded training metrics to MLflow run "
f"[cyan]{result.run_id}[/cyan]."
)
else:
CONSOLE.print(
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
"Uploaded SageMaker final metrics only.[/yellow]"
)
if result.registered_model_version:
CONSOLE.print(
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
"([cyan]experiment-latest[/cyan])"
)
except Exception as e:
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
CONSOLE.print(
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
raise typer.Exit(1)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
return
time.sleep(poll_interval)
except KeyboardInterrupt:
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
raise typer.Exit(130)
@app.command()
def start(
upload_metrics: bool = typer.Option(
False,
"--upload-metrics",
help="Wait for completion, then upload training metrics to MLflow",
),
poll_interval: int = typer.Option(
DEFAULT_POLL_INTERVAL_SECONDS,
"--poll-interval",
min=1,
help="Seconds between status checks when --upload-metrics is used",
),
config: str = CONFIG_OPT,
) -> None:
"""Submit a SageMaker training job."""
cfg = load_cfg(config)
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
raise typer.Exit(1)
if not cfg.sagemaker.training.image_uri:
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
CONSOLE.print(
"Find pre-built images at: "
"https://aws.github.io/deep-learning-containers/reference/available_images"
)
raise typer.Exit(1)
try:
role_arn = _sagemaker_role_arn(config, cfg)
except RuntimeError as e:
CONSOLE.print(f"[red]{e}[/red]")
raise typer.Exit(1)
tracker = _tracker(cfg)
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
training_job = sm_ops.TrainingJobRequest(
role_arn=role_arn,
image_uri=cfg.sagemaker.training.image_uri,
instance_type=cfg.sagemaker.training.instance_type,
instance_count=cfg.sagemaker.training.instance_count,
s3_train_uri=s3_train_uri,
s3_output_path=s3_output,
job_name=job_name,
hyperparameters=cfg.sagemaker.training.hyperparameters,
entry_point=cfg.sagemaker.training.entry_point,
source_dir=cfg.sagemaker.training.source_dir,
)
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
st = state_ops.store(config)
st.set_last_training_job(job_name)
try:
run_id = tracker.start_training_run(
training_job,
region=cfg.aws.region,
profile=cfg.aws.profile,
role_arn=role_arn,
)
except Exception as e:
run_id = None
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
CONSOLE.print(
"The SageMaker job is still running. Upload metrics after completion with "
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
)
if run_id:
st.update_training_job(job_name, mlflow_run_id=run_id)
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
if upload_metrics:
_wait_and_upload_metrics(
job_name=job_name,
poll_interval=poll_interval,
config_path=config,
cfg=cfg,
)
else:
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@app.command()
def status(
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
config: str = CONFIG_OPT,
) -> None:
"""Show training job status."""
cfg = load_cfg(config)
st = state_ops.store(config)
if not job_name:
job_name = st.get_last_training_job()
if not job_name:
CONSOLE.print(
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
)
raise typer.Exit(1)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
_print_training_status(status)
@app.command(name="list")
def list_jobs(
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
config: str = CONFIG_OPT,
) -> None:
"""List recent training jobs."""
cfg = load_cfg(config)
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
if not jobs:
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
return
table = Table(title="Training Jobs")
table.add_column("Name", style="cyan")
table.add_column("Status")
table.add_column("Created")
for job in jobs:
status_value = str(job["TrainingJobStatus"])
color = _STATUS_COLOR.get(status_value, "white")
table.add_row(
str(job["TrainingJobName"]),
f"[{color}]{status_value}[/{color}]",
str(job.get("CreationTime", "")),
)
CONSOLE.print(table)

70
src/commands/upload.py Normal file
View File

@@ -0,0 +1,70 @@
from pathlib import Path
import typer
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from src.aws import s3 as s3_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
app = typer.Typer()
@app.command()
def upload(
path: Path = typer.Argument(..., help="Local file or directory to upload"),
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a local file or directory to S3."""
cfg = load_cfg(config)
if path.is_file():
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
try:
with CONSOLE.status(f"Uploading {path.name}..."):
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
except Exception as e:
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] {path.name} -> {uri}")
return
if path.is_dir():
if s3_key is not None:
CONSOLE.print("[red]--s3-key can only be used when uploading a single file.[/red]")
raise typer.Exit(1)
files = [file for file in path.rglob("*") if file.is_file()]
if not files:
CONSOLE.print("[yellow]No files found in directory.[/yellow]")
raise typer.Exit(0)
prefix = cfg.s3.data_prefix
CONSOLE.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
try:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=CONSOLE,
) as progress:
task = progress.add_task("Uploading...", total=len(files))
count = s3_ops.upload_dir(
cfg.aws.region,
cfg.aws.profile,
cfg.s3.bucket,
str(path),
prefix,
on_progress=lambda: progress.advance(task),
)
except Exception as e:
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
return
CONSOLE.print(f"[red]Path not found: {path}[/red]")
raise typer.Exit(1)

View File

@@ -14,7 +14,7 @@ def load_config(path: str = "config.yaml") -> Config:
config_path = Path(path)
if not config_path.exists():
raise FileNotFoundError(
f"Config file not found: {config_path}. Run 'qai-cli init' to create one."
f"Config file not found: {config_path}. Run 'qc-cli init' to create one."
)
with open(config_path) as f:
data = yaml.safe_load(f)

View File

@@ -1,30 +1,68 @@
from enum import Enum
from typing import Any, Literal
import re
from enum import StrEnum
from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from qai_hub.client import Device
class MlflowMode(str, Enum):
class MlflowMode(StrEnum):
disabled = "disabled"
create = "create"
existing = "existing"
class MlflowServerSize(str, Enum):
class MlflowServerSize(StrEnum):
small = "Small"
medium = "Medium"
large = "Large"
class Boto3SessionKwargs(TypedDict):
profile_name: str
region_name: str
class AwsConfig(BaseModel):
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
profile: str = "default"
@property
def boto3_session(self) -> Boto3SessionKwargs:
return {"profile_name": self.profile, "region_name": self.region}
DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
GENERATED_STACK_PREFIX = "qc-cli-mlops-"
class InfraConfig(BaseModel):
stack_name: str
@property
def effective_bootstrap_qualifier(self) -> str:
sanitized = re.sub(r"[^a-z0-9]", "", self.stack_name.lower())
if not sanitized:
return DEFAULT_BOOTSTRAP_QUALIFIER
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
suffix = re.sub(r"[^a-z0-9]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX).lower())
if suffix:
return f"q{suffix}"[:10]
return f"q{sanitized}"[:10]
@property
def effective_toolkit_stack_name(self) -> str:
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
suffix = re.sub(r"[^A-Za-z0-9-]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
if suffix:
return f"{self.stack_name}-bootstrap"
return f"{self.stack_name}-bootstrap"
class S3Config(BaseModel):
bucket: str = "my-onnx-bucket"
bucket: str = "my-qc-mlops-bucket"
data_prefix: str = "data/"
model_prefix: str = "models/"
@@ -39,13 +77,35 @@ class TrainingConfig(BaseModel):
class SageMakerConfig(BaseModel):
role_name: str = "qai-cli-sagemaker-role"
role_name: str = ""
training: TrainingConfig = Field(default_factory=TrainingConfig)
class AIHubConfig(BaseModel):
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
target_runtime: str = "tflite"
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
job_name: str | None = None
model_name: str | None = None
compile_options: str | None = None
profile_options: str | None = None
quantize_options: str | None = None
output_dir: str = "build/qai-hub"
@field_validator("device", mode="before")
@classmethod
def parse_device(cls, value: Any) -> Any:
if isinstance(value, str):
return Device(value)
return value
class MlflowConfig(BaseModel):
mode: MlflowMode = MlflowMode.disabled
tracking_server_name: str | None = None
experiment_name: str = "qc-cli-training"
registered_model_name: str = "qc-cli-model"
register_trained_models: bool = True
artifact_prefix: str = "mlflow/"
tracking_server_size: MlflowServerSize = MlflowServerSize.small
mlflow_version: str | None = None
@@ -54,13 +114,27 @@ class MlflowConfig(BaseModel):
@model_validator(mode="after")
def require_tracking_server_name(self) -> "MlflowConfig":
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
if self.mode is MlflowMode.existing and not self.tracking_server_name:
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
return self
class Config(BaseModel):
infra: InfraConfig
aws: AwsConfig = Field(default_factory=AwsConfig)
s3: S3Config = Field(default_factory=S3Config)
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
aihub: AIHubConfig = Field(default_factory=AIHubConfig)
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
@property
def managed_mlflow_tracking_server_name(self) -> str:
return f"{self.infra.stack_name}-mlflow"
@property
def effective_mlflow_tracking_server_name(self) -> str | None:
if self.mlflow.mode is MlflowMode.disabled:
return None
if self.mlflow.mode is MlflowMode.existing:
return self.mlflow.tracking_server_name
return self.managed_mlflow_tracking_server_name

View File

@@ -5,17 +5,27 @@ from typing import Any
from src.infra.state import state_path, write_infra_state
STACK_NAME = "QaiCliStack"
def bootstrap(
*,
profile: str,
account_id: str,
region: str,
bootstrap_qualifier: str,
toolkit_stack_name: str,
cloudformation_execution_policy: str | None = None,
) -> None:
cmd = ["cdk", "bootstrap", f"aws://{account_id}/{region}", "--profile", profile]
cmd = [
"cdk",
"bootstrap",
f"aws://{account_id}/{region}",
"--profile",
profile,
"--qualifier",
bootstrap_qualifier,
"--toolkit-stack-name",
toolkit_stack_name,
]
if cloudformation_execution_policy:
cmd.extend(["--cloudformation-execution-policies", cloudformation_execution_policy])
_run(cmd)
@@ -26,6 +36,9 @@ def deploy(
profile: str,
account_id: str,
region: str,
stack_name: str,
bootstrap_qualifier: str,
toolkit_stack_name: str,
config_path: str,
config_dir: str,
config_snapshot: dict[str, Any],
@@ -35,19 +48,24 @@ def deploy(
"deploy",
profile=profile,
account_id=account_id,
stack_name=stack_name,
bootstrap_qualifier=bootstrap_qualifier,
toolkit_stack_name=toolkit_stack_name,
config_path=config_path,
delete_bucket_data=False,
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
_run(cmd)
outputs = _read_outputs(outputs_file)
outputs = _read_outputs(outputs_file, stack_name)
state = {
"stack_name": STACK_NAME,
"stack_name": stack_name,
"aws": {
"account_id": account_id,
"region": region,
"profile": profile,
},
"bootstrap_qualifier": bootstrap_qualifier,
"toolkit_stack_name": toolkit_stack_name,
"config": config_snapshot,
"outputs": outputs,
}
@@ -59,6 +77,9 @@ def destroy(
*,
profile: str,
account_id: str,
stack_name: str,
bootstrap_qualifier: str,
toolkit_stack_name: str,
config_path: str,
delete_bucket_data: bool,
) -> None:
@@ -67,6 +88,9 @@ def destroy(
"deploy",
profile=profile,
account_id=account_id,
stack_name=stack_name,
bootstrap_qualifier=bootstrap_qualifier,
toolkit_stack_name=toolkit_stack_name,
config_path=config_path,
delete_bucket_data=True,
) + ["--require-approval", "never"]
@@ -76,6 +100,9 @@ def destroy(
"destroy",
profile=profile,
account_id=account_id,
stack_name=stack_name,
bootstrap_qualifier=bootstrap_qualifier,
toolkit_stack_name=toolkit_stack_name,
config_path=config_path,
delete_bucket_data=delete_bucket_data,
) + ["--force"]
@@ -87,26 +114,35 @@ def _cdk_cmd(
*,
profile: str,
account_id: str,
stack_name: str,
bootstrap_qualifier: str,
toolkit_stack_name: str,
config_path: str,
delete_bucket_data: bool,
) -> list[str]:
cmd = [
"cdk",
action,
STACK_NAME,
stack_name,
"--app",
"python app.py",
"--profile",
profile,
]
if action == "deploy":
cmd.extend(["--toolkit-stack-name", toolkit_stack_name])
cmd.extend([
"-c",
f"account_id={account_id}",
"-c",
f"config={config_path}",
"-c",
f"stack_name={STACK_NAME}",
f"stack_name={stack_name}",
"-c",
f"bootstrap_qualifier={bootstrap_qualifier}",
"-c",
f"delete_bucket_data={str(delete_bucket_data).lower()}",
]
])
return cmd
@@ -119,9 +155,9 @@ def _run(cmd: list[str]) -> None:
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
def _read_outputs(path: Path) -> dict[str, str]:
def _read_outputs(path: Path, stack_name: str) -> dict[str, str]:
if not path.exists():
return {}
with open(path) as f:
data = json.load(f)
return data.get(STACK_NAME, {})
return data.get(stack_name, {})

View File

@@ -9,7 +9,7 @@ from constructs import Construct
from src.config import Config, MlflowMode
class QaiStack(Stack):
class QCStack(Stack):
def __init__(
self,
scope: Construct,
@@ -34,7 +34,7 @@ class QaiStack(Stack):
role = iam.CfnRole(
self,
"SageMakerRole",
role_name=config.sagemaker.role_name,
role_name=config.sagemaker.role_name or None,
assume_role_policy_document=self._sagemaker_trust_policy(),
managed_policy_arns=[
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
@@ -74,6 +74,7 @@ class QaiStack(Stack):
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
if config.mlflow.mode is MlflowMode.create:
tracking_server_name = config.managed_mlflow_tracking_server_name
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
artifact_uri = (
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
@@ -145,14 +146,14 @@ class QaiStack(Stack):
"MlflowTrackingServer",
artifact_store_uri=artifact_uri,
role_arn=mlflow_role.attr_arn,
tracking_server_name=config.mlflow.tracking_server_name or "",
tracking_server_name=tracking_server_name,
automatic_model_registration=config.mlflow.automatic_model_registration,
mlflow_version=config.mlflow.mlflow_version,
tracking_server_size=config.mlflow.tracking_server_size.value,
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
)
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)

View File

@@ -2,7 +2,7 @@ import json
from pathlib import Path
from typing import Any
INFRA_STATE_FILE = ".qai-cli-infra.json"
INFRA_STATE_FILE = ".qc-cli-infra.json"
def state_path(config_dir: str) -> Path:

View File

@@ -1,39 +1,14 @@
from pathlib import Path
import typer
import yaml
from rich.console import Console
from src.commands import infra
from src.config import Config
from src.commands import ai_hub, infra, init, mlflow, train, upload
app = typer.Typer(
help="qai-cli: End-to-end model managment for Qualcomm AI Hub.",
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
no_args_is_help=True,
)
app.add_typer(init.app)
app.add_typer(upload.app)
app.add_typer(mlflow.app, name="mlflow")
app.add_typer(infra.app, name="infra")
console = Console()
@app.command()
def init(
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
) -> None:
"""Write a starter config.yaml to the current directory."""
dest = Path(output)
if dest.exists() and not force:
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
raise typer.Exit(1)
config = Config()
dest.parent.mkdir(parents=True, exist_ok=True)
with open(dest, "w") as f:
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
console.print(
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
"before running other commands."
)
app.add_typer(train.app, name="train")
app.add_typer(ai_hub.app, name="ai-hub")

0
src/qualcomm/__init__.py Normal file
View File

114
src/qualcomm/aihub_jobs.py Normal file
View File

@@ -0,0 +1,114 @@
from pathlib import Path
from typing import Any, TypedDict
import qai_hub.hub as hub
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
class ModelJobResult(TypedDict):
job: CompileJob | QuantizeJob
job_id: str
model: Model
model_id: str
class InferenceJobResult(TypedDict):
job: InferenceJob
job_id: str
outputs: Any
class ProfileJobResult(TypedDict):
job: ProfileJob
job_id: str
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
return {name: value if isinstance(value, list) else [value] for name, value in inputs.items()}
def submit_compile_job(
model: Model,
device: Device,
input_specs: dict[str, tuple[tuple[int, ...], str]],
target_runtime: str,
options: str | None = None,
job_name: str | None = None,
) -> ModelJobResult:
compile_options = f"--target_runtime {target_runtime}"
if options:
compile_options = f"{compile_options} {options}"
job = hub.submit_compile_job(
model=model,
device=device,
name=job_name,
input_specs=input_specs,
options=compile_options,
)
target_model = job.get_target_model()
if target_model is None:
raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.")
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
def submit_inference_job(
model: Model,
device: Device,
inputs: dict[str, Any],
output_dir: str | Path,
job_name: str | None = None,
) -> InferenceJobResult:
job = hub.submit_inference_job(
model=model,
device=device,
inputs=_dataset_entries(inputs),
name=job_name,
)
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
data = job.download_output_data(str(out))
return {"job": job, "job_id": str(job.job_id), "outputs": data}
def submit_profile_job(
model: Model,
device: Device,
options: str | None = None,
job_name: str | None = None,
) -> ProfileJobResult:
job = hub.submit_profile_job(
model=model,
device=device,
name=job_name,
options=options or "",
)
return {"job": job, "job_id": str(job.job_id)}
def submit_quantize_job(
model: Model,
calibration_data: dict[str, Any],
options: str | None = None,
job_name: str | None = None,
) -> ModelJobResult:
job = hub.submit_quantize_job(
model=model,
calibration_data=_dataset_entries(calibration_data),
weights_dtype=QuantizeDtype.INT8,
activations_dtype=QuantizeDtype.INT8,
name=job_name,
options=options or "",
)
target_model = job.get_target_model()
if target_model is None:
raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.")
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
def download_model(model_id: str, output_path: str | Path) -> str:
dest = Path(output_path)
dest.parent.mkdir(parents=True, exist_ok=True)
model = hub.get_model(model_id)
result = model.download(str(dest))
return str(result or dest)

83
src/qualcomm/artifacts.py Normal file
View File

@@ -0,0 +1,83 @@
import tarfile
from dataclasses import dataclass
from pathlib import Path
from src.aws import s3 as s3_ops
from src.aws import sagemaker as sm_ops
from src.config import Config
@dataclass(frozen=True)
class ResolvedOnnx:
onnx_path: Path
model_artifact: str | None
run_name: str
def _safe_extract(tar: tarfile.TarFile, dest: Path) -> None:
dest_root = dest.resolve()
for member in tar.getmembers():
target = (dest / member.name).resolve()
if dest_root != target and dest_root not in target.parents:
raise ValueError(f"Unsafe tar member path: {member.name}")
tar.extractall(dest, filter="data")
def _find_onnx(root: Path, explicit: str | None = None) -> Path:
if explicit:
p = Path(explicit)
if not p.is_absolute():
p = root / p
if not p.exists():
raise FileNotFoundError(f"ONNX file not found: {p}")
return p
matches = sorted(root.rglob("model.onnx"))
if not matches:
matches = sorted(root.rglob("*.onnx"))
if not matches:
raise FileNotFoundError(f"No ONNX file found under {root}")
if len(matches) > 1:
joined = ", ".join(str(p.relative_to(root)) for p in matches)
raise ValueError(f"Multiple ONNX files found ({joined}). Pass --onnx-path.")
return matches[0]
def resolve_onnx(
cfg: Config,
output_dir: str,
from_job: str | None = None,
model_s3_uri: str | None = None,
onnx_path: str | None = None,
last_training_job: str | None = None,
) -> ResolvedOnnx:
if onnx_path:
path = Path(onnx_path)
if path.exists():
return ResolvedOnnx(onnx_path=path, model_artifact=None, run_name=path.stem)
job = from_job or last_training_job
artifact = model_s3_uri
if not artifact:
if not job:
raise ValueError("No model source found. Pass --onnx-path, --model-s3-uri, --from-job, or run training first.")
artifact = sm_ops.get_model_artifacts(cfg.aws.region, cfg.aws.profile, job)
run_name = job or Path(artifact).name.removesuffix(".tar.gz").replace("/", "-")
root = Path(output_dir) / run_name / "source"
tar_path = root / "model.tar.gz"
s3_ops.download_file(cfg.aws.region, cfg.aws.profile, artifact, str(tar_path))
extract_dir = root / "extracted"
extract_dir.mkdir(parents=True, exist_ok=True)
try:
with tarfile.open(tar_path, "r:gz") as tar:
_safe_extract(tar, extract_dir)
except tarfile.TarError as e:
raise ValueError(f"Invalid model tarball: {tar_path}") from e
return ResolvedOnnx(
onnx_path=_find_onnx(extract_dir, onnx_path),
model_artifact=artifact,
run_name=run_name,
)

85
src/state.py Normal file
View File

@@ -0,0 +1,85 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
STATE_FILE = ".qc-cli.json"
@dataclass(frozen=True)
class CliStateStore:
config_dir: str = "."
@property
def path(self) -> Path:
return Path(self.config_dir) / STATE_FILE
def read(self) -> dict[str, Any]:
if not self.path.exists():
return {}
with open(self.path) as f:
value = json.load(f)
return dict(value) if isinstance(value, dict) else {}
def update(self, **updates: Any) -> None:
state = self.read()
state.update(updates)
self._write(state)
def get(self, key: str, default: Any = None) -> Any:
return self.read().get(key, default)
def get_last_training_job(self) -> str | None:
value = self.get("last_training_job")
return str(value) if value else None
def get_last_model_artifact(self) -> str | None:
value = self.get("last_model_artifact")
return str(value) if value else None
def get_last_optimized_model_id(self) -> str | None:
value = self.get("last_optimized_model_id")
return str(value) if value else None
def get_last_quantized_model_id(self) -> str | None:
value = self.get("last_quantized_model_id")
return str(value) if value else None
def get_last_compiled_model_id(self) -> str | None:
value = self.get("last_compiled_model_id")
return str(value) if value else None
def get_last_downloaded_model(self) -> str | None:
value = self.get("last_downloaded_model")
return str(value) if value else None
def set_last_training_job(self, job_name: str) -> None:
self.update(last_training_job=job_name)
def get_training_job(self, job_name: str) -> dict[str, Any]:
jobs = self._training_jobs(self.read())
value = jobs.get(job_name, {})
return dict(value) if isinstance(value, dict) else {}
def update_training_job(self, job_name: str, **updates: Any) -> None:
state = self.read()
jobs = self._training_jobs(state)
jobs[job_name] = {**jobs.get(job_name, {}), **updates}
state["training_jobs"] = jobs
self._write(state)
def set_latest_experiment_model_version(self, version: str) -> None:
self.update(latest_experiment_model_version=version)
def _write(self, state: dict[str, Any]) -> None:
with open(self.path, "w") as f:
json.dump(state, f, indent=2)
def _training_jobs(self, state: dict[str, Any]) -> dict[str, Any]:
value = state.get("training_jobs", {})
return dict(value) if isinstance(value, dict) else {}
def store(config_path: str) -> CliStateStore:
config_dir = str(Path(config_path).parent)
return CliStateStore(config_dir)

3
src/tracking/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]

93
src/tracking/metrics.py Normal file
View File

@@ -0,0 +1,93 @@
import json
import math
import tarfile
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any
METRICS_ARTIFACT_NAME = "training_metrics.json"
METRICS_SCHEMA_VERSION = 1
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
@dataclass(frozen=True)
class MetricStep:
step: int
metrics: dict[str, float]
@dataclass(frozen=True)
class TrainingMetrics:
steps: list[MetricStep]
summary: dict[str, float]
raw: dict[str, Any]
def parse_training_metrics(data: bytes) -> TrainingMetrics:
try:
value = json.loads(data)
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
if not isinstance(value, dict):
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
raw_steps = value.get("steps")
if not isinstance(raw_steps, list):
raise ValueError("training metrics 'steps' must be a list")
steps: list[MetricStep] = []
previous_step: int | None = None
for index, raw_step in enumerate(raw_steps):
if not isinstance(raw_step, dict):
raise ValueError(f"training metrics step {index} must be an object")
step = raw_step.get("step")
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
raise ValueError(f"training metrics step {index} has an invalid 'step'")
if previous_step is not None and step <= previous_step:
raise ValueError("training metrics steps must be unique and strictly increasing")
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
steps.append(MetricStep(step=step, metrics=metrics))
previous_step = step
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
return TrainingMetrics(steps=steps, summary=summary, raw=value)
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
with tarfile.open(archive_path, mode="r:*") as archive:
matches = [
member
for member in archive.getmembers()
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
]
if not matches:
return None
if len(matches) > 1:
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
raise ValueError(
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
)
extracted = archive.extractfile(matches[0])
if extracted is None:
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
return extracted.read()
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
if not isinstance(value, dict):
raise ValueError(f"{context} 'metrics' must be an object")
metrics: dict[str, float] = {}
for raw_name, raw_value in value.items():
if not isinstance(raw_name, str) or not raw_name:
raise ValueError(f"{context} contains an invalid metric name")
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
metric_value = float(raw_value)
if not math.isfinite(metric_value):
raise ValueError(f"{context} metric '{raw_name}' must be finite")
metrics[raw_name] = metric_value
return metrics

267
src/tracking/mlflow.py Normal file
View File

@@ -0,0 +1,267 @@
import os
import tempfile
from dataclasses import dataclass
from typing import Any, Protocol
import mlflow
from mlflow.tracking import MlflowClient
from src.aws import s3
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
@dataclass(frozen=True)
class FinalizeResult:
registered_model_version: str | None = None
warnings: tuple[str, ...] = ()
class Tracker(Protocol):
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult: ...
def ensure_training_run(self, job_name: str) -> str: ...
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool: ...
@dataclass(frozen=True)
class NoopTracker:
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
return None
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
return FinalizeResult()
def ensure_training_run(self, job_name: str) -> str:
raise RuntimeError("MLflow is disabled.")
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool:
raise RuntimeError("MLflow is disabled.")
@dataclass(frozen=True)
class MlflowTracker:
tracking_uri: str
experiment_name: str
registered_model_name: str
register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod
def from_config(cls, cfg: Config) -> Tracker:
if cfg.mlflow.mode is MlflowMode.disabled:
return NoopTracker()
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_backend = mlflow_tracking_backend_from_config(cfg)
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
with tracking_backend.auth_env():
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)
return cls(
tracking_uri=tracking_uri,
experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
)
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
with mlflow.start_run(run_name=training_job.job_name) as run:
run_id = str(run.info.run_id)
self._log_params(
self.tracking_backend.training_run_params(
training_job,
region=region,
profile=profile,
role_arn=role_arn,
)
)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "train start",
**self.tracking_backend.training_run_tags(training_job),
}
)
return run_id
def finalize_training_run(
self,
*,
run_id: str | None,
training_job_status: Any,
region: str,
profile: str,
command: str,
) -> FinalizeResult:
if not run_id:
return FinalizeResult()
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", command)
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
return FinalizeResult()
if not self.register_trained_models:
return FinalizeResult()
client = MlflowClient()
self._ensure_registered_model(client, self.registered_model_name)
version = client.create_model_version(
name=self.registered_model_name,
source=training_job_status.model_artifacts,
run_id=run_id,
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
**self.tracking_backend.model_version_tags(training_job_status),
},
)
version_number = str(version.version)
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
mlflow.set_tag("qc_cli.registered_model_version", version_number)
return FinalizeResult(registered_model_version=version_number)
def ensure_training_run(self, job_name: str) -> str:
with self.tracking_backend.auth_env():
client = MlflowClient()
experiment = client.get_experiment_by_name(self.experiment_name)
if experiment is None:
experiment_id = mlflow.create_experiment(self.experiment_name)
else:
experiment_id = experiment.experiment_id
for run in client.search_runs([experiment_id], max_results=1000):
if run.data.tags.get("sagemaker.job_name") == job_name:
return str(run.info.run_id)
run = client.create_run(
experiment_id,
run_name=job_name,
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "mlflow upload-metrics",
"sagemaker.job_name": job_name,
},
)
return str(run.info.run_id)
def upload_training_metrics(
self,
*,
run_id: str,
training_job_status: Any,
region: str,
profile: str,
) -> bool:
if not training_job_status.model_artifacts:
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw)
history_uploaded = self._log_training_metrics(
training_job_status.model_artifacts,
region=region,
profile=profile,
)
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
return history_uploaded
def _log_params(self, params: dict[str, Any]) -> None:
cleaned = {key: str(value) for key, value in params.items() if value is not None}
if cleaned:
mlflow.log_params(cleaned)
def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
metrics = {}
for metric in training_job.get("FinalMetricDataList", []):
name = metric.get("MetricName")
value = metric.get("Value")
if name and value is not None:
metrics[str(name)] = float(value)
if metrics:
mlflow.log_metrics(metrics)
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
archive_path = s3.download_file(
region,
profile,
model_artifacts,
os.path.join(temp_dir, "model.tar.gz"),
)
metrics_data = read_training_metrics_from_tar(archive_path)
if metrics_data is None:
return False
metrics = parse_training_metrics(metrics_data)
for metric_step in metrics.steps:
if metric_step.metrics:
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
if metrics.summary:
mlflow.log_metrics(metrics.summary)
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
return True
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
try:
client.get_registered_model(name)
except Exception:
client.create_registered_model(name)

75
src/tracking/upload.py Normal file
View File

@@ -0,0 +1,75 @@
from dataclasses import dataclass
from src import state as state_ops
from src.aws import sagemaker as sm_ops
from src.config import Config, MlflowMode
from src.tracking.mlflow import MlflowTracker
@dataclass(frozen=True)
class MetricsUploadResult:
run_id: str
registered_model_version: str | None = None
metrics_history_uploaded: bool = True
def upload_training_metrics(
*,
job_name: str,
config_path: str,
cfg: Config,
force: bool = False,
) -> MetricsUploadResult:
if cfg.mlflow.mode is MlflowMode.disabled:
raise RuntimeError("MLflow is disabled in config.yaml.")
st = state_ops.store(config_path)
job_state = st.get_training_job(job_name)
if job_state.get("mlflow_metrics_uploaded") and not force:
return MetricsUploadResult(
run_id=str(job_state.get("mlflow_run_id") or ""),
registered_model_version=(
str(job_state["registered_model_version"])
if job_state.get("registered_model_version")
else None
),
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
)
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
if status.status != "Completed":
raise RuntimeError(
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
)
tracker = MlflowTracker.from_config(cfg)
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
st.update_training_job(job_name, mlflow_run_id=run_id)
metrics_history_uploaded = tracker.upload_training_metrics(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
)
finalized = tracker.finalize_training_run(
run_id=run_id,
training_job_status=status,
region=cfg.aws.region,
profile=cfg.aws.profile,
command="mlflow upload-metrics",
)
updates = {
"mlflow_metrics_uploaded": True,
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
"mlflow_finalized_status": status.status,
}
if finalized.registered_model_version:
updates["registered_model_version"] = finalized.registered_model_version
st.update_training_job(job_name, **updates)
if finalized.registered_model_version:
st.set_latest_experiment_model_version(finalized.registered_model_version)
return MetricsUploadResult(
run_id=run_id,
registered_model_version=finalized.registered_model_version,
metrics_history_uploaded=metrics_history_uploaded,
)

1972
uv.lock generated

File diff suppressed because it is too large Load Diff