Compare commits
17 Commits
2618b30d40
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| b56b77330c | |||
| a1ffbb77c5 | |||
| 522ddc74e2 | |||
|
|
5360a482fc | ||
|
|
6a560a8610 | ||
| d244150d98 | |||
| d7c7158464 | |||
| 6bc25dc183 | |||
|
|
71a95aa3a7 | ||
| a3f3060e13 | |||
| e9ada2612f | |||
| 6ac9702dc5 | |||
| 0e728cc193 | |||
| 62ffe163e8 | |||
|
|
cfc04b473f | ||
| 717257dd75 | |||
| 75255b37d0 |
223
.gitignore
vendored
223
.gitignore
vendored
@@ -1,5 +1,224 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
# Temporary file for partial code execution
|
||||||
|
tempCodeRunnerFile.py
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
|
|
||||||
.venv/
|
.venv/
|
||||||
config.yaml
|
config*.yaml
|
||||||
cdk.out/
|
cdk.out/
|
||||||
.qai-cli-infra*
|
.qc-cli*.json
|
||||||
|
examples/*/data/
|
||||||
|
|||||||
220
README.md
220
README.md
@@ -1,66 +1,102 @@
|
|||||||
# qai-cli
|
# qc-cli
|
||||||
|
|
||||||
A CLI for the Qualcomm model MLOps pipeline — browse and download models from Qualcomm AI Hub, fine-tune them on custom datasets using SageMaker, validate inference, and prepare artifacts for Qualcomm hardware deployment.
|
A CLI for Qualcomm's MLOps pipeline — browse and download models from Qualcomm AI Hub, fine-tune them on custom datasets using SageMaker, validate inference, and prepare artifacts for Qualcomm hardware deployment.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.13+
|
- Python 3.13+
|
||||||
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
|
||||||
- AWS account with credentials configured (`aws configure`) when using `qai-cli infra`
|
- AWS account with credentials configured (`aws configure`) when using `qc-cli infra`
|
||||||
- AWS CDK CLI (`npm install -g aws-cdk`) when using `qai-cli infra setup` or `qai-cli infra destroy`
|
- AWS CDK CLI (`npm install -g aws-cdk`) when using `qc-cli infra setup` or `qc-cli infra destroy`
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone <repo>
|
git clone <repo>
|
||||||
cd qai-cli
|
cd qc-cli
|
||||||
uv sync
|
uv sync
|
||||||
```
|
```
|
||||||
|
|
||||||
Run commands with `uv run qai-cli <command>` or activate the venv first:
|
Run commands with `uv run qc-cli <command>` or activate the venv first:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
qai-cli --help
|
qc-cli --help
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 1. Create config.yaml in the current directory
|
# 1. Create config.yaml in the current directory
|
||||||
qai-cli init
|
qc-cli init
|
||||||
|
|
||||||
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.role_name
|
# 2. Edit config.yaml — at minimum set sagemaker.training.image_uri
|
||||||
|
|
||||||
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
||||||
# This is the step that requires the AWS CDK CLI.
|
# This is the step that requires the AWS CDK CLI.
|
||||||
qai-cli infra setup
|
qc-cli infra setup
|
||||||
|
|
||||||
|
# 4. Upload training data, then submit a SageMaker training job.
|
||||||
|
qc-cli upload ./my-dataset
|
||||||
|
qc-cli train start
|
||||||
|
qc-cli train status
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
`qai-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
|
`qc-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
infra:
|
||||||
|
stack_name: qc-cli-mlops-1a2b3c4d5e6f
|
||||||
|
|
||||||
aws:
|
aws:
|
||||||
region: us-east-1
|
region: us-east-1
|
||||||
profile: default # AWS CLI profile name
|
profile: default # AWS CLI profile name
|
||||||
|
|
||||||
s3:
|
s3:
|
||||||
bucket: your-unique-bucket-name
|
bucket: qc-cli-mlops-1a2b3c4d5e6f-data
|
||||||
|
|
||||||
sagemaker:
|
sagemaker:
|
||||||
role_name: qai-cli-sagemaker-role
|
training:
|
||||||
|
image_uri: "" # ECR URI for your training container
|
||||||
|
instance_type: ml.m5.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
entry_point: null # Optional: script inside source_dir
|
||||||
|
source_dir: null # Optional: local dir packaged and uploaded automatically
|
||||||
|
hyperparameters: {}
|
||||||
|
|
||||||
|
aihub:
|
||||||
|
device:
|
||||||
|
name: Samsung Galaxy S25 (Family)
|
||||||
|
target_runtime: tflite
|
||||||
|
input_specs: {} # Required before running qc-cli ai-hub commands
|
||||||
|
job_name: null # Optional prefix for AI Hub Workbench jobs
|
||||||
|
model_name: null # Optional name for uploaded local ONNX models
|
||||||
|
compile_options: null
|
||||||
|
profile_options: null
|
||||||
|
quantize_options: null
|
||||||
|
output_dir: build/qai-hub
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`qc-cli init` generates the `infra.stack_name` and `s3.bucket` namespace once and writes it to `config.yaml`. Keep these values stable for a deployment; changing them points the CLI at different infrastructure.
|
||||||
|
|
||||||
|
The CLI isolates both application resources and CDK bootstrap resources. The application CloudFormation stack uses `infra.stack_name`, the S3 bucket uses the same generated namespace because bucket names are globally unique, and the SageMaker IAM role uses a CloudFormation-generated physical name. CDK bootstrap resources are derived internally from `infra.stack_name`, including a bootstrap stack named `<stack_name>-bootstrap` and a matching non-default CDK asset bucket qualifier. `qc-cli infra destroy` removes the application stack but leaves the CDK bootstrap stack in place; the command prints the retained bootstrap stack name.
|
||||||
|
|
||||||
|
`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point.
|
||||||
|
|
||||||
To provision an MLflow tracking server, set:
|
To provision an MLflow tracking server, set:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
mlflow:
|
mlflow:
|
||||||
mode: create
|
mode: create
|
||||||
tracking_server_name: your-tracking-server-name
|
experiment_name: qc-cli-training
|
||||||
|
registered_model_name: qc-cli-model
|
||||||
|
register_trained_models: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
In `create` mode, the CLI manages the tracking server name from `infra.stack_name`; you do not need to set `tracking_server_name`.
|
||||||
|
|
||||||
To use an existing MLflow tracking server, set:
|
To use an existing MLflow tracking server, set:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
@@ -69,28 +105,163 @@ mlflow:
|
|||||||
tracking_server_name: your-tracking-server-name
|
tracking_server_name: your-tracking-server-name
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Training metrics can be upload with `train start --upload-metrics` or `mlflow upload-metrics`.
|
||||||
|
|
||||||
|
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli mlflow open --config config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This opens a browser to a fresh presigned URL. It works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
### `init`
|
### `init`
|
||||||
|
|
||||||
```
|
```
|
||||||
qai-cli init Write config.yaml
|
qc-cli init Write config.yaml
|
||||||
qai-cli init --output <path> Write config to a custom path
|
qc-cli init --output <path> Write config to a custom path
|
||||||
qai-cli init --force Overwrite an existing config file
|
qc-cli init --force Overwrite an existing config file
|
||||||
```
|
```
|
||||||
|
|
||||||
### `infra`
|
### `infra`
|
||||||
|
|
||||||
```
|
```
|
||||||
qai-cli infra setup Deploy the CDK stack
|
qc-cli infra setup Deploy the CDK stack
|
||||||
qai-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
|
qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
|
||||||
qai-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
|
qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
|
||||||
qai-cli infra status Show CDK stack/resource status
|
qc-cli infra status Show CDK stack/resource status
|
||||||
qai-cli infra destroy Destroy stack, retaining S3 data
|
qc-cli infra destroy Destroy stack, retaining S3 data
|
||||||
qai-cli infra destroy --yes Destroy stack without confirmation
|
qc-cli infra destroy --yes Destroy stack without confirmation
|
||||||
qai-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`--cloudformation-execution-policy` is a one-time CDK bootstrap option, not a `config.yaml` setting. Pass it on `infra setup` when you need the CDK bootstrap CloudFormation execution role to use a policy other than the default `AdministratorAccess`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
|
||||||
|
```
|
||||||
|
|
||||||
|
### `mlflow`
|
||||||
|
|
||||||
|
```
|
||||||
|
qc-cli mlflow open Open a presigned MLflow UI URL
|
||||||
|
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. Use `--force` to upload the metrics again.
|
||||||
|
|
||||||
|
### `upload`
|
||||||
|
|
||||||
|
```
|
||||||
|
qc-cli upload <file> Upload a single file to S3
|
||||||
|
qc-cli upload <dir> Upload all files in a directory tree to S3
|
||||||
|
qc-cli upload <file> --s3-key <key> Upload a file to a custom S3 key
|
||||||
|
```
|
||||||
|
|
||||||
|
Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads default to `s3://<bucket>/<data_prefix>/<filename>`. Directory uploads are recursive, preserve paths relative to the uploaded directory, and place files under `s3://<bucket>/<data_prefix>/`.
|
||||||
|
|
||||||
|
### `train`
|
||||||
|
|
||||||
|
```
|
||||||
|
qc-cli train start Submit a SageMaker training job
|
||||||
|
qc-cli train start --upload-metrics Submit, wait, and upload metrics
|
||||||
|
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||||
|
qc-cli train list List recent training jobs
|
||||||
|
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||||
|
```
|
||||||
|
|
||||||
|
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||||
|
|
||||||
|
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||||
|
|
||||||
|
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||||
|
|
||||||
|
### `ai-hub`
|
||||||
|
|
||||||
|
```
|
||||||
|
qc-cli ai-hub upload <calibration.npz|calibration-dir> <inputs.npz|inputs.npy>
|
||||||
|
qc-cli ai-hub upload <calibration> <inputs> --from-step validate
|
||||||
|
qc-cli ai-hub optimize [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||||
|
qc-cli ai-hub quantize <calibration.npz|calibration-dir> [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||||
|
qc-cli ai-hub compile [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||||
|
qc-cli ai-hub validate <inputs.npz|inputs.npy> [--model-id ID] [--input-name NAME]
|
||||||
|
qc-cli ai-hub profile [--model-id ID]
|
||||||
|
qc-cli ai-hub download [--model-id ID] [--output PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
`ai-hub upload` optimizes to ONNX, quantizes, validates, and profiles. When `aihub.target_runtime` is not `onnx`, it also compiles the quantized model to that deployment runtime. The initial ONNX optimization gives external models Workbench provenance and applies compiler optimization passes before quantization.
|
||||||
|
|
||||||
|
Resume behavior:
|
||||||
|
|
||||||
|
```text
|
||||||
|
--from-step optimize Run optimize, quantize, optional final compile, validate, and profile.
|
||||||
|
--from-step quantize Quantize the last optimized ONNX, then optionally compile, validate, and profile.
|
||||||
|
--from-step compile Skip optimize and quantize; finalize the last quantized model for the target runtime.
|
||||||
|
--from-step validate Skip optimize, quantize, and compile; validate the last compiled model.
|
||||||
|
--from-step profile Skip optimize, quantize, compile, and validate; profile the last compiled model.
|
||||||
|
```
|
||||||
|
|
||||||
|
When a step runs in the current command, `upload` passes its returned model ID directly to the next step. When a step is skipped, the next step resolves the needed model ID from `.qc-cli.json`. This avoids re-running earlier AI Hub jobs when you only need to continue from a later step.
|
||||||
|
|
||||||
|
`ai-hub optimize` compiles an external model with `--target_runtime onnx`. `ai-hub quantize` uses an explicit `--model-id`, the last optimized ONNX model, or an explicit/local model source in that order. `ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options, last quantized model, then the last training job. For `target_runtime: onnx`, upload treats the quantized ONNX as the final model and skips a redundant second compile. `ai-hub download` remains separate because downloading is outside the Workbench processing loop.
|
||||||
|
|
||||||
|
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
|
||||||
|
|
||||||
|
## Model lifecycle
|
||||||
|
|
||||||
|
The CLI uses neutral experiment naming for trained artifacts and reserves release terminology for an explicit promotion step.
|
||||||
|
|
||||||
|
Current behavior:
|
||||||
|
|
||||||
|
1. `qc-cli train start` submits a SageMaker training job.
|
||||||
|
2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
|
||||||
|
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
|
||||||
|
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
|
||||||
|
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
|
||||||
|
- `qc_cli.stage=experiment`
|
||||||
|
- `qc_cli.artifact_kind=trained_source`
|
||||||
|
- `qc_cli.source=sagemaker`
|
||||||
|
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||||
|
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||||
|
|
||||||
|
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"schema_version": 1,
|
||||||
|
"steps": [
|
||||||
|
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
|
||||||
|
],
|
||||||
|
"summary": {"summary.best_epoch": 0}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues model registration without per-epoch history. A malformed metrics artifact still fails the upload command without affecting the trained model or model registration.
|
||||||
|
|
||||||
|
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
||||||
|
|
||||||
|
Example future metadata:
|
||||||
|
|
||||||
|
```text
|
||||||
|
qc-cli-model version 12
|
||||||
|
qc_cli.stage=experiment
|
||||||
|
qc_cli.artifact_kind=trained_source
|
||||||
|
qc_cli.source=sagemaker
|
||||||
|
|
||||||
|
qc-cli-model-aihub version 3
|
||||||
|
qc_cli.stage=ai_hub_compiled
|
||||||
|
qc_cli.artifact_kind=deployable
|
||||||
|
qc_cli.parent_registered_model_name=qc-cli-model
|
||||||
|
qc_cli.parent_model_version=12
|
||||||
|
qc_cli.runtime=tflite
|
||||||
|
qc_cli.quantization=int8
|
||||||
|
qc_cli.target_device=Samsung Galaxy S25
|
||||||
|
```
|
||||||
|
|
||||||
|
In that flow, `experiment-latest` remains a training convenience alias. Release selection is a separate promotion decision based on the derived artifact, not on the experiment name.
|
||||||
|
|
||||||
## AWS permissions required
|
## AWS permissions required
|
||||||
|
|
||||||
The IAM user or role running the CLI needs:
|
The IAM user or role running the CLI needs:
|
||||||
@@ -101,6 +272,7 @@ The IAM user or role running the CLI needs:
|
|||||||
| CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM |
|
| CreateRole, GetRole, DeleteRole, AttachRolePolicy, DetachRolePolicy | IAM |
|
||||||
| CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation |
|
| CreateStack, UpdateStack, DeleteStack, DescribeStacks, DescribeStackEvents | CloudFormation |
|
||||||
| GetCallerIdentity | STS |
|
| GetCallerIdentity | STS |
|
||||||
|
| CreateTrainingJob, DescribeTrainingJob, ListTrainingJobs | SageMaker AI |
|
||||||
| CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` |
|
| CreateMlflowTrackingServer, DescribeMlflowTrackingServer, DeleteMlflowTrackingServer | SageMaker AI, when `mlflow.mode` is `create` or `existing` |
|
||||||
|
|
||||||
`AdministratorAccess` covers all of the above.
|
`AdministratorAccess` covers all of the above.
|
||||||
|
|||||||
8
app.py
8
app.py
@@ -3,22 +3,24 @@ import os
|
|||||||
import aws_cdk as cdk
|
import aws_cdk as cdk
|
||||||
|
|
||||||
from src.commands.utils import load_config
|
from src.commands.utils import load_config
|
||||||
from src.infra.stack import QaiStack
|
from src.infra.stack import QCStack
|
||||||
|
|
||||||
app = cdk.App()
|
app = cdk.App()
|
||||||
|
|
||||||
config_path = app.node.try_get_context("config") or "config.yaml"
|
config_path = app.node.try_get_context("config") or "config.yaml"
|
||||||
stack_name = app.node.try_get_context("stack_name") or "QaiCliStack"
|
|
||||||
account_id = app.node.try_get_context("account_id") or os.getenv("CDK_DEFAULT_ACCOUNT")
|
account_id = app.node.try_get_context("account_id") or os.getenv("CDK_DEFAULT_ACCOUNT")
|
||||||
delete_bucket_data = str(app.node.try_get_context("delete_bucket_data") or "false").lower() == "true"
|
delete_bucket_data = str(app.node.try_get_context("delete_bucket_data") or "false").lower() == "true"
|
||||||
|
|
||||||
cfg = load_config(config_path)
|
cfg = load_config(config_path)
|
||||||
|
stack_name = app.node.try_get_context("stack_name") or cfg.infra.stack_name
|
||||||
|
bootstrap_qualifier = app.node.try_get_context("bootstrap_qualifier") or cfg.infra.effective_bootstrap_qualifier
|
||||||
|
|
||||||
QaiStack(
|
QCStack(
|
||||||
app,
|
app,
|
||||||
stack_name,
|
stack_name,
|
||||||
config=cfg,
|
config=cfg,
|
||||||
delete_bucket_data=delete_bucket_data,
|
delete_bucket_data=delete_bucket_data,
|
||||||
|
synthesizer=cdk.DefaultStackSynthesizer(qualifier=bootstrap_qualifier),
|
||||||
env=cdk.Environment(
|
env=cdk.Environment(
|
||||||
account=account_id,
|
account=account_id,
|
||||||
region=cfg.aws.region,
|
region=cfg.aws.region,
|
||||||
|
|||||||
285
examples/meter-detection/README.md
Normal file
285
examples/meter-detection/README.md
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
# YOLO26 Electric Meter Detection Example
|
||||||
|
|
||||||
|
This example trains a YOLO26 object detection model on the Roboflow Universe electric meter dataset using the existing `qc-cli` SageMaker training flow.
|
||||||
|
|
||||||
|
The workflow is intentionally command driven. Run each step yourself so you can inspect the dataset, update `config.yaml`, and decide when to submit the SageMaker job.
|
||||||
|
|
||||||
|
Dataset:
|
||||||
|
|
||||||
|
```text
|
||||||
|
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Install or sync the project dependencies: `uv sync`
|
||||||
|
- The virtual environment is activated.
|
||||||
|
- AWS credentials configured for the profile in `config.yaml`
|
||||||
|
- Infrastructure already deployed with `qc-cli infra setup`
|
||||||
|
|
||||||
|
## 1. Download The Dataset
|
||||||
|
|
||||||
|
Register or sign in to Roboflow, then open the dataset page:
|
||||||
|
|
||||||
|
```text
|
||||||
|
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
||||||
|
```
|
||||||
|
|
||||||
|
Download the dataset in YOLOv26 format from the Roboflow UI, then extract the downloaded archive into:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/meter-detection/data/electric-meter-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
The `data.yaml` file should be directly under that folder:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/meter-detection/data/electric-meter-detection/data.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Do not move `data.yaml` into the `train/` split folder.
|
||||||
|
|
||||||
|
After extracting, confirm the dataset has a YOLO data file and image splits:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
find examples/meter-detection/data/electric-meter-detection -maxdepth 2 -type d | sort
|
||||||
|
find examples/meter-detection/data/electric-meter-detection -name data.yaml -print
|
||||||
|
```
|
||||||
|
|
||||||
|
Open `examples/meter-detection/data/electric-meter-detection/data.yaml` and make sure the split paths are relative to that folder:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
path: .
|
||||||
|
train: train/images
|
||||||
|
val: valid/images
|
||||||
|
test: test/images
|
||||||
|
```
|
||||||
|
|
||||||
|
If your downloaded dataset does not include a `test/` folder, remove the `test:` line.
|
||||||
|
|
||||||
|
The expected layout is similar to:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/meter-detection/data/electric-meter-detection/
|
||||||
|
data.yaml
|
||||||
|
train/
|
||||||
|
valid/
|
||||||
|
test/
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. Configure SageMaker Training
|
||||||
|
|
||||||
|
Update `config.yaml` so the training section points at this example's source directory:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sagemaker:
|
||||||
|
training:
|
||||||
|
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||||
|
instance_type: ml.g4dn.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
source_dir: examples/meter-detection/source
|
||||||
|
entry_point: train.py
|
||||||
|
hyperparameters:
|
||||||
|
model: yolo26n.pt
|
||||||
|
epochs: 25
|
||||||
|
imgsz: 640
|
||||||
|
batch: 16
|
||||||
|
workers: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `yolo26n.pt` for a lightweight first YOLO26 run. If those weights are unavailable in the installed Ultralytics package, use `yolo11n.pt` as the established fallback:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model: yolo11n.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
The `source/requirements.txt` file is installed by the SageMaker PyTorch container before running `train.py`.
|
||||||
|
|
||||||
|
For a CPU smoke test, use a CPU instance and reduce the workload:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sagemaker:
|
||||||
|
training:
|
||||||
|
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||||
|
instance_type: ml.m4.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
source_dir: examples/meter-detection/source
|
||||||
|
entry_point: train.py
|
||||||
|
hyperparameters:
|
||||||
|
model: yolo26n.pt
|
||||||
|
epochs: 1
|
||||||
|
imgsz: 320
|
||||||
|
batch: 4
|
||||||
|
workers: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. Check Infrastructure
|
||||||
|
|
||||||
|
Confirm the CLI can see the configured SageMaker role and S3 bucket:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli infra status
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. Upload The Dataset
|
||||||
|
|
||||||
|
Upload the downloaded Roboflow dataset to the `s3.data_prefix` configured in `config.yaml`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli upload examples/meter-detection/data/electric-meter-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
Directory uploads preserve paths relative to the uploaded directory, so SageMaker receives the dataset root with `data.yaml` plus the split directories.
|
||||||
|
|
||||||
|
In SageMaker, this uploaded dataset root is mounted at `/opt/ml/input/data/train`. That `train` path is the SageMaker channel name, not the YOLO `train/` split folder.
|
||||||
|
|
||||||
|
## 5. Start Training
|
||||||
|
|
||||||
|
Submit the SageMaker training job:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli train start
|
||||||
|
```
|
||||||
|
|
||||||
|
The command prints the submitted SageMaker job name. Check progress with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli train status
|
||||||
|
```
|
||||||
|
|
||||||
|
Or pass the job name explicitly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
||||||
|
```
|
||||||
|
|
||||||
|
To submit the job, wait for completion, and automatically import metrics and register the model, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli train start --upload-metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
||||||
|
|
||||||
|
The metrics can be also submitted using:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli mlflow upload-metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## SageMaker Outputs
|
||||||
|
|
||||||
|
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
|
||||||
|
|
||||||
|
This example writes:
|
||||||
|
|
||||||
|
```text
|
||||||
|
best.pt
|
||||||
|
model.onnx
|
||||||
|
metrics.json
|
||||||
|
training_metrics.json
|
||||||
|
```
|
||||||
|
|
||||||
|
The archive is stored under the configured `s3.model_prefix`.
|
||||||
|
|
||||||
|
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
|
||||||
|
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
|
||||||
|
are more meaningful than classification accuracy when assessing model quality.
|
||||||
|
|
||||||
|
## 6. Configure Qualcomm AI Hub
|
||||||
|
|
||||||
|
Authenticate with Qualcomm AI Hub:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qai-hub configure --api_token
|
||||||
|
```
|
||||||
|
|
||||||
|
Add AI Hub settings to `config.yaml`. The input name and image size must match the ONNX model exported by this example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
aihub:
|
||||||
|
device:
|
||||||
|
name: Dragonwing IQ-9075 EVK
|
||||||
|
target_runtime: onnx
|
||||||
|
input_specs:
|
||||||
|
images: [[1, 3, 640, 640], float32]
|
||||||
|
job_name: meter-detection
|
||||||
|
model_name: meter-detection
|
||||||
|
output_dir: build/qai-hub/meter-detection
|
||||||
|
```
|
||||||
|
|
||||||
|
The ONNX graph is the source of truth. The export normally uses the same value as `sagemaker.training.hyperparameters.imgsz`, but changing `config.yaml` after training does not resize an existing model. For example, a model exported with `imgsz: 320` requires `images: [[1, 3, 320, 320], float32]`.
|
||||||
|
|
||||||
|
## 7. Prepare AI Hub Inputs
|
||||||
|
|
||||||
|
Generate calibration samples and a validation input from the downloaded dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python examples/meter-detection/prepare_aihub_inputs.py --image-size 640
|
||||||
|
```
|
||||||
|
|
||||||
|
This writes:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/meter-detection/data/aihub_calibration/*.npy
|
||||||
|
examples/meter-detection/data/inputs.npz
|
||||||
|
```
|
||||||
|
|
||||||
|
The script applies the preprocessing expected by the exported YOLO model: aspect-ratio-preserving letterboxing, RGB channel order, channel-first layout, and pixel values normalized to `[0, 1]`.
|
||||||
|
|
||||||
|
## 8. Upload To Qualcomm AI Hub
|
||||||
|
|
||||||
|
Use the SageMaker job name printed by `qc-cli train start`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli ai-hub upload \
|
||||||
|
examples/meter-detection/data/aihub_calibration \
|
||||||
|
examples/meter-detection/data/inputs.npz \
|
||||||
|
--from-job qc-cli-YYYYMMDD-HHMMSS
|
||||||
|
```
|
||||||
|
|
||||||
|
The command downloads the job's `model.tar.gz`, finds `model.onnx`, and runs the following AI Hub workflow:
|
||||||
|
|
||||||
|
1. Compile the external ONNX to a Workbench-optimized ONNX model.
|
||||||
|
2. Quantize the optimized ONNX model.
|
||||||
|
3. Compile the quantized model when the configured deployment runtime is not `onnx`.
|
||||||
|
4. Validate and profile the final model.
|
||||||
|
|
||||||
|
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
|
||||||
|
|
||||||
|
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local model. Replace the job name in both paths:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
|
||||||
|
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
|
||||||
|
--output build/qai-hub/meter-detection/model.aihub.onnx
|
||||||
|
|
||||||
|
qc-cli ai-hub upload \
|
||||||
|
examples/meter-detection/data/aihub_calibration \
|
||||||
|
examples/meter-detection/data/inputs.npz \
|
||||||
|
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
|
||||||
|
```
|
||||||
|
|
||||||
|
Download the compiled artifact after the workflow completes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
qc-cli ai-hub download --output build/qai-hub/meter-detection/model.tflite
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Hyperparameters
|
||||||
|
|
||||||
|
Values under `sagemaker.training.hyperparameters` are passed to `source/train.py` as command-line arguments.
|
||||||
|
|
||||||
|
| Name | Type | Default | Description |
|
||||||
|
|---|---:|---:|---|
|
||||||
|
| `model` | string | `yolo26n.pt` | Ultralytics model weights or model YAML. |
|
||||||
|
| `epochs` | int | `25` | Number of training epochs. |
|
||||||
|
| `imgsz` | int | `640` | Square training image size. |
|
||||||
|
| `batch` | int | `16` | Images per training batch. |
|
||||||
|
| `workers` | int | `2` | DataLoader worker count. |
|
||||||
|
| `patience` | int | `20` | Early stopping patience. |
|
||||||
|
| `device` | string | auto | Optional Ultralytics device value such as `0` or `cpu`. |
|
||||||
|
| `data-yaml` | string | auto | Optional path to `data.yaml`; normally discovered from the uploaded dataset root. |
|
||||||
|
| `dataset-dir` | string | `SM_CHANNEL_TRAIN` | Uploaded dataset root mounted by SageMaker. |
|
||||||
|
|
||||||
|
Do not set `dataset-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
||||||
92
examples/meter-detection/prepare_aihub_inputs.py
Normal file
92
examples/meter-detection/prepare_aihub_inputs.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Prepare Qualcomm AI Hub calibration and validation inputs for the meter detector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/meter-detection/data/electric-meter-detection"),
|
||||||
|
help="Root of the extracted Roboflow dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--calibration-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/meter-detection/data/aihub_calibration"),
|
||||||
|
help="Directory where .npy calibration samples will be written.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-file",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/meter-detection/data/inputs.npz"),
|
||||||
|
help="Validation .npz input file for qc-cli ai-hub validate.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--input-name", default="images", help="ONNX input name.")
|
||||||
|
parser.add_argument("--image-size", type=int, default=640, help="Square image size used for ONNX export.")
|
||||||
|
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
|
||||||
|
"""Apply Ultralytics-style letterboxing and produce an NCHW float32 tensor."""
|
||||||
|
with Image.open(path) as source:
|
||||||
|
image = source.convert("RGB")
|
||||||
|
|
||||||
|
scale = min(image_size / image.width, image_size / image.height)
|
||||||
|
resized_width = round(image.width * scale)
|
||||||
|
resized_height = round(image.height * scale)
|
||||||
|
image = image.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
canvas = Image.new("RGB", (image_size, image_size), (114, 114, 114))
|
||||||
|
left = round((image_size - resized_width) / 2 - 0.1)
|
||||||
|
top = round((image_size - resized_height) / 2 - 0.1)
|
||||||
|
canvas.paste(image, (left, top))
|
||||||
|
|
||||||
|
array = np.asarray(canvas, dtype=np.float32) / 255.0
|
||||||
|
return np.transpose(array, (2, 0, 1))[None, ...].astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
if args.image_size < 1:
|
||||||
|
raise SystemExit("--image-size must be at least 1")
|
||||||
|
if args.samples < 1:
|
||||||
|
raise SystemExit("--samples must be at least 1")
|
||||||
|
|
||||||
|
images = sorted(
|
||||||
|
path
|
||||||
|
for path in args.dataset_dir.rglob("*")
|
||||||
|
if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS and path.parent.name == "images"
|
||||||
|
)
|
||||||
|
if not images:
|
||||||
|
raise SystemExit(f"No images found under {args.dataset_dir}")
|
||||||
|
|
||||||
|
args.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.input_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for stale_sample in args.calibration_dir.glob("sample_*.npy"):
|
||||||
|
stale_sample.unlink()
|
||||||
|
|
||||||
|
prepared: list[np.ndarray] = []
|
||||||
|
for index, image_path in enumerate(images[: args.samples]):
|
||||||
|
sample = preprocess_image(image_path, args.image_size)
|
||||||
|
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
|
||||||
|
prepared.append(sample)
|
||||||
|
|
||||||
|
np.savez(args.input_file, **{args.input_name: prepared[0]}) # pyright: ignore[reportArgumentType]
|
||||||
|
print(f"Wrote {len(prepared)} calibration samples to {args.calibration_dir}")
|
||||||
|
print(f"Wrote validation input to {args.input_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
examples/meter-detection/source/requirements.txt
Normal file
3
examples/meter-detection/source/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
ultralytics>=8.3.0
|
||||||
|
pyyaml>=6.0.3
|
||||||
|
onnx>=1.16.0
|
||||||
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import onnx # type: ignore[reportMissingImports]
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
|
||||||
|
model = onnx.load(path)
|
||||||
|
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
|
||||||
|
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
|
||||||
|
|
||||||
|
destination = output_path or path
|
||||||
|
if len(retained_value_info) != len(model.graph.value_info):
|
||||||
|
del model.graph.value_info[:]
|
||||||
|
model.graph.value_info.extend(retained_value_info)
|
||||||
|
|
||||||
|
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
onnx.save(model, destination)
|
||||||
|
return destination
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("onnx_path", type=Path)
|
||||||
|
parser.add_argument("--output", type=Path)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
written = sanitize_onnx(args.onnx_path, args.output)
|
||||||
|
print(f"Saved sanitized ONNX model to {written}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
128
examples/meter-detection/source/train.py
Normal file
128
examples/meter-detection/source/train.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""SageMaker entry point for YOLO electric meter detection training."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from sanitize_onnx import sanitize_onnx
|
||||||
|
from training_metrics import write_training_metrics
|
||||||
|
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--model", default="yolo26n.pt")
|
||||||
|
parser.add_argument("--epochs", type=int, default=25)
|
||||||
|
parser.add_argument("--imgsz", type=int, default=640)
|
||||||
|
parser.add_argument("--batch", type=int, default=16)
|
||||||
|
parser.add_argument("--workers", type=int, default=2)
|
||||||
|
parser.add_argument("--patience", type=int, default=20)
|
||||||
|
parser.add_argument("--device", default=None)
|
||||||
|
parser.add_argument("--data-yaml", default=None)
|
||||||
|
parser.add_argument("--dataset-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
||||||
|
parser.add_argument("--train-dir", dest="dataset_dir", help=argparse.SUPPRESS)
|
||||||
|
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def find_data_yaml(dataset_dir: Path, explicit_path: str | None) -> Path:
|
||||||
|
if explicit_path:
|
||||||
|
data_yaml = Path(explicit_path)
|
||||||
|
if data_yaml.is_file():
|
||||||
|
return data_yaml
|
||||||
|
raise FileNotFoundError(f"Configured data.yaml does not exist: {data_yaml}")
|
||||||
|
|
||||||
|
matches = sorted(dataset_dir.rglob("data.yaml"))
|
||||||
|
if not matches:
|
||||||
|
raise FileNotFoundError(f"Could not find data.yaml under {dataset_dir}")
|
||||||
|
if len(matches) > 1:
|
||||||
|
print(f"Found multiple data.yaml files; using {matches[0]}")
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data_yaml(data_yaml: Path) -> Path:
|
||||||
|
"""Write a SageMaker-local data file rooted at the uploaded dataset."""
|
||||||
|
dataset_root = data_yaml.parent
|
||||||
|
data = yaml.safe_load(data_yaml.read_text(encoding="utf-8"))
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError(f"Expected a mapping in {data_yaml}")
|
||||||
|
|
||||||
|
normalized = dict(data)
|
||||||
|
normalized["path"] = str(dataset_root)
|
||||||
|
if "val" not in normalized and "valid" in normalized:
|
||||||
|
normalized["val"] = normalized.pop("valid")
|
||||||
|
|
||||||
|
prepared_path = dataset_root / "data.sagemaker.yaml"
|
||||||
|
prepared_path.write_text(yaml.safe_dump(normalized, sort_keys=False), encoding="utf-8")
|
||||||
|
print(f"Prepared dataset config: {prepared_path}")
|
||||||
|
return prepared_path
|
||||||
|
|
||||||
|
|
||||||
|
def copy_if_exists(source: Path, destination: Path) -> None:
|
||||||
|
if source.exists():
|
||||||
|
shutil.copy2(source, destination)
|
||||||
|
print(f"Saved {destination}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
dataset_dir = Path(args.dataset_dir)
|
||||||
|
model_dir = Path(args.model_dir)
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
data_yaml = prepare_data_yaml(find_data_yaml(dataset_dir, args.data_yaml))
|
||||||
|
model = YOLO(args.model)
|
||||||
|
|
||||||
|
train_kwargs: dict[str, Any] = {
|
||||||
|
"data": str(data_yaml),
|
||||||
|
"epochs": args.epochs,
|
||||||
|
"imgsz": args.imgsz,
|
||||||
|
"batch": args.batch,
|
||||||
|
"workers": args.workers,
|
||||||
|
"patience": args.patience,
|
||||||
|
"project": str(model_dir / "runs"),
|
||||||
|
"name": "train",
|
||||||
|
"exist_ok": True,
|
||||||
|
}
|
||||||
|
if args.device:
|
||||||
|
train_kwargs["device"] = args.device
|
||||||
|
|
||||||
|
results = model.train(**train_kwargs)
|
||||||
|
save_dir = Path(results.save_dir)
|
||||||
|
best_pt = save_dir / "weights" / "best.pt"
|
||||||
|
last_pt = save_dir / "weights" / "last.pt"
|
||||||
|
trained_weights = best_pt if best_pt.exists() else last_pt
|
||||||
|
if not trained_weights.exists():
|
||||||
|
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
|
||||||
|
|
||||||
|
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
|
||||||
|
copy_if_exists(trained_weights, model_dir / "best.pt")
|
||||||
|
trained_model = YOLO(str(trained_weights))
|
||||||
|
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
||||||
|
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
|
||||||
|
print(f"Saved {saved_onnx_path}")
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"model": args.model,
|
||||||
|
"epochs": args.epochs,
|
||||||
|
"imgsz": args.imgsz,
|
||||||
|
"batch": args.batch,
|
||||||
|
"workers": args.workers,
|
||||||
|
"patience": args.patience,
|
||||||
|
"data_yaml": str(data_yaml),
|
||||||
|
"weights": str(trained_weights),
|
||||||
|
"onnx": str(saved_onnx_path),
|
||||||
|
}
|
||||||
|
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||||
|
print(f"Saved model artifacts to {model_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
82
examples/meter-detection/source/training_metrics.py
Normal file
82
examples/meter-detection/source/training_metrics.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
METRIC_NAMES = {
|
||||||
|
"metrics/precision(B)": "val.precision",
|
||||||
|
"metrics/recall(B)": "val.recall",
|
||||||
|
"metrics/mAP50(B)": "val.map50",
|
||||||
|
"metrics/mAP50-95(B)": "val.map50_95",
|
||||||
|
"train/box_loss": "train.box_loss",
|
||||||
|
"train/cls_loss": "train.cls_loss",
|
||||||
|
"train/dfl_loss": "train.dfl_loss",
|
||||||
|
"val/box_loss": "val.box_loss",
|
||||||
|
"val/cls_loss": "val.cls_loss",
|
||||||
|
"val/dfl_loss": "val.dfl_loss",
|
||||||
|
"time": "train.elapsed_seconds",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def write_training_metrics(results_csv: Path, destination: Path) -> None:
|
||||||
|
steps = _read_metric_steps(results_csv)
|
||||||
|
summary = _build_summary(steps)
|
||||||
|
payload = {
|
||||||
|
"schema_version": 1,
|
||||||
|
"steps": steps,
|
||||||
|
"summary": summary,
|
||||||
|
}
|
||||||
|
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||||
|
print(f"Saved {destination}")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
|
||||||
|
if not results_csv.is_file():
|
||||||
|
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
|
||||||
|
|
||||||
|
steps: list[dict[str, Any]] = []
|
||||||
|
with results_csv.open(newline="", encoding="utf-8") as csv_file:
|
||||||
|
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
|
||||||
|
row = {str(key).strip(): value for key, value in raw_row.items()}
|
||||||
|
raw_epoch = row.pop("epoch", row_index)
|
||||||
|
step = int(float(raw_epoch))
|
||||||
|
metrics: dict[str, float] = {}
|
||||||
|
for source_name, raw_value in row.items():
|
||||||
|
if raw_value is None or not raw_value.strip():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
value = float(raw_value)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if math.isfinite(value):
|
||||||
|
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
|
||||||
|
steps.append({"step": step, "metrics": metrics})
|
||||||
|
return steps
|
||||||
|
|
||||||
|
|
||||||
|
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
|
||||||
|
if not steps:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
summary: dict[str, float] = {}
|
||||||
|
final_step = steps[-1]
|
||||||
|
summary["summary.final_epoch"] = float(final_step["step"])
|
||||||
|
for name, value in final_step["metrics"].items():
|
||||||
|
summary[f"summary.final.{name}"] = value
|
||||||
|
|
||||||
|
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
|
||||||
|
if scored_steps:
|
||||||
|
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
|
||||||
|
summary["summary.best_epoch"] = float(best_step["step"])
|
||||||
|
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
|
||||||
|
if "val.map50" in best_step["metrics"]:
|
||||||
|
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_metric_name(name: str) -> str:
|
||||||
|
normalized = name.replace("/", ".")
|
||||||
|
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
|
||||||
|
return normalized.strip("._") or "unnamed"
|
||||||
@@ -3,21 +3,25 @@ requires = ["hatchling"]
|
|||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "qai-cli"
|
name = "qc-cli"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "CLI for SageMaker ONNX training and Qualcomm AI Hub optimization"
|
description = "CLI for training and deploying models for Qualcomm AI Hub"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aws-cdk-lib>=2.180.0",
|
"aws-cdk-lib>=2.180.0",
|
||||||
"typer==0.25.0",
|
"typer==0.25.0",
|
||||||
"boto3>=1.34,<1.42",
|
"boto3>=1.34,<1.42",
|
||||||
"constructs>=10.0.0",
|
"constructs>=10.0.0",
|
||||||
|
"mlflow>=3.0",
|
||||||
|
"numpy>=1.26",
|
||||||
"pydantic>=2.13.3",
|
"pydantic>=2.13.3",
|
||||||
"pyyaml>=6.0.3",
|
"pyyaml>=6.0.3",
|
||||||
|
"qai-hub>=0.49.0",
|
||||||
|
"sagemaker-mlflow>=0.4.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
qai-cli = "src.main:app"
|
qc-cli = "src.main:app"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src"]
|
packages = ["src"]
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ from typing import Any
|
|||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
from src.infra.provisioning import STACK_NAME
|
|
||||||
|
|
||||||
|
def stack_status(region: str, profile: str, stack_name: str) -> dict[str, Any] | None:
|
||||||
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
|
|
||||||
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
|
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
|
||||||
try:
|
try:
|
||||||
stack = client.describe_stacks(StackName=STACK_NAME)["Stacks"][0]
|
stack = client.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
message = e.response.get("Error", {}).get("Message", "")
|
message = e.response.get("Error", {}).get("Message", "")
|
||||||
if "does not exist" in message:
|
if "does not exist" in message:
|
||||||
|
|||||||
17
src/aws/iam.py
Normal file
17
src/aws/iam.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from mypy_boto3_iam import IAMClient
|
||||||
|
|
||||||
|
|
||||||
|
def _client(profile: str) -> IAMClient:
|
||||||
|
return boto3.Session(profile_name=profile).client("iam")
|
||||||
|
|
||||||
|
|
||||||
|
def get_role_arn(profile: str, role_name: str) -> str | None:
|
||||||
|
client = _client(profile)
|
||||||
|
try:
|
||||||
|
return client.get_role(RoleName=role_name)["Role"]["Arn"]
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response.get("Error", {}).get("Code") == "NoSuchEntity":
|
||||||
|
return None
|
||||||
|
raise
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
@@ -17,3 +20,55 @@ def describe_tracking_server(region: str, profile: str, name: str) -> dict[str,
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
|
||||||
|
server = describe_tracking_server(region, profile, name)
|
||||||
|
if not server:
|
||||||
|
raise ValueError(f"MLflow tracking server not found: {name}")
|
||||||
|
|
||||||
|
arn = server.get("TrackingServerArn")
|
||||||
|
if not arn:
|
||||||
|
raise ValueError(f"MLflow tracking server has no ARN: {name}")
|
||||||
|
return str(arn)
|
||||||
|
|
||||||
|
|
||||||
|
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
|
||||||
|
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||||
|
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
||||||
|
return str(response["AuthorizedUrl"])
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
|
||||||
|
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
|
||||||
|
if credentials is None:
|
||||||
|
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
|
||||||
|
|
||||||
|
frozen_credentials = credentials.get_frozen_credentials()
|
||||||
|
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
|
||||||
|
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
|
||||||
|
|
||||||
|
env_updates = {
|
||||||
|
"AWS_PROFILE": profile,
|
||||||
|
"AWS_DEFAULT_REGION": region,
|
||||||
|
"AWS_REGION": region,
|
||||||
|
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
|
||||||
|
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
|
||||||
|
}
|
||||||
|
if frozen_credentials.token:
|
||||||
|
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
|
||||||
|
|
||||||
|
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
|
||||||
|
previous_env = {key: os.environ.get(key) for key in restore_keys}
|
||||||
|
try:
|
||||||
|
os.environ.update(env_updates)
|
||||||
|
if not frozen_credentials.token:
|
||||||
|
os.environ.pop("AWS_SESSION_TOKEN", None)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for key, value in previous_env.items():
|
||||||
|
if value is None:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
else:
|
||||||
|
os.environ[key] = value
|
||||||
|
|||||||
69
src/aws/s3.py
Normal file
69
src/aws/s3.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from mypy_boto3_s3 import S3Client
|
||||||
|
|
||||||
|
|
||||||
|
def _client(region: str, profile: str) -> S3Client:
|
||||||
|
return boto3.Session(profile_name=profile, region_name=region).client("s3")
|
||||||
|
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
bucket: str,
|
||||||
|
local_path: str,
|
||||||
|
s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
_client(region, profile).upload_file(local_path, bucket, s3_key)
|
||||||
|
return f"s3://{bucket}/{s3_key}"
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
s3_uri: str,
|
||||||
|
local_path: str,
|
||||||
|
) -> str:
|
||||||
|
if not s3_uri.startswith("s3://"):
|
||||||
|
raise ValueError(f"Expected S3 URI, got: {s3_uri}")
|
||||||
|
bucket_key = s3_uri.removeprefix("s3://")
|
||||||
|
bucket, _, key = bucket_key.partition("/")
|
||||||
|
if not bucket or not key:
|
||||||
|
raise ValueError(f"Expected S3 URI with bucket and key, got: {s3_uri}")
|
||||||
|
dest = Path(local_path)
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
_client(region, profile).download_file(bucket, key, str(dest))
|
||||||
|
return str(dest)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_dir(
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
bucket: str,
|
||||||
|
local_dir: str,
|
||||||
|
s3_prefix: str,
|
||||||
|
on_progress: Callable[[], None] | None = None,
|
||||||
|
) -> int:
|
||||||
|
root = Path(local_dir)
|
||||||
|
files = [file for file in root.rglob("*") if file.is_file()]
|
||||||
|
if not files:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
client = _client(region, profile)
|
||||||
|
prefix = s3_prefix.rstrip("/")
|
||||||
|
|
||||||
|
def upload_one(file_path: Path) -> None:
|
||||||
|
key = f"{prefix}/{file_path.relative_to(root)}"
|
||||||
|
client.upload_file(str(file_path), bucket, key)
|
||||||
|
if on_progress:
|
||||||
|
on_progress()
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||||
|
futures = [pool.submit(upload_one, file) for file in files]
|
||||||
|
for future in as_completed(futures):
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
return len(files)
|
||||||
143
src/aws/sagemaker.py
Normal file
143
src/aws/sagemaker.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from mypy_boto3_sagemaker import SageMakerClient
|
||||||
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||||
|
from mypy_boto3_sagemaker.type_defs import (
|
||||||
|
CreateTrainingJobRequestTypeDef,
|
||||||
|
ResourceConfigTypeDef,
|
||||||
|
TrainingJobSummaryTypeDef,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.config import Boto3SessionKwargs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingJobRequest:
|
||||||
|
role_arn: str
|
||||||
|
image_uri: str
|
||||||
|
instance_type: TrainingInstanceTypeType
|
||||||
|
instance_count: int
|
||||||
|
s3_train_uri: str
|
||||||
|
s3_output_path: str
|
||||||
|
job_name: str
|
||||||
|
hyperparameters: dict[str, Any] = field(default_factory=dict)
|
||||||
|
entry_point: str | None = None
|
||||||
|
source_dir: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingJobStatus:
|
||||||
|
name: str
|
||||||
|
status: str
|
||||||
|
created: datetime | None
|
||||||
|
modified: datetime | None
|
||||||
|
model_artifacts: str | None
|
||||||
|
failure_reason: str | None
|
||||||
|
raw: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
|
||||||
|
return boto3.Session(**session).client("sagemaker")
|
||||||
|
|
||||||
|
|
||||||
|
def _upload_source_dir(
|
||||||
|
session: Boto3SessionKwargs,
|
||||||
|
source_dir: str,
|
||||||
|
s3_output_path: str,
|
||||||
|
job_name: str,
|
||||||
|
) -> str:
|
||||||
|
import io
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||||
|
tar.add(source_dir, arcname=".")
|
||||||
|
buf.seek(0)
|
||||||
|
|
||||||
|
without_scheme = s3_output_path.removeprefix("s3://")
|
||||||
|
bucket, _, prefix = without_scheme.partition("/")
|
||||||
|
key = f"{prefix.rstrip('/')}/{job_name}/source/sourcedir.tar.gz".lstrip("/")
|
||||||
|
|
||||||
|
boto3.Session(**session).client("s3").upload_fileobj(buf, bucket, key)
|
||||||
|
return f"s3://{bucket}/{key}"
|
||||||
|
|
||||||
|
|
||||||
|
def start_training_job(session: Boto3SessionKwargs, job: TrainingJobRequest) -> str:
|
||||||
|
hp = {k: str(v) for k, v in job.hyperparameters.items()}
|
||||||
|
|
||||||
|
if job.source_dir:
|
||||||
|
s3_code_uri = _upload_source_dir(
|
||||||
|
session,
|
||||||
|
job.source_dir,
|
||||||
|
job.s3_output_path,
|
||||||
|
job.job_name,
|
||||||
|
)
|
||||||
|
hp["sagemaker_program"] = job.entry_point or "train.py"
|
||||||
|
hp["sagemaker_submit_directory"] = s3_code_uri
|
||||||
|
|
||||||
|
resource_config: ResourceConfigTypeDef = {
|
||||||
|
"InstanceType": job.instance_type,
|
||||||
|
"InstanceCount": job.instance_count,
|
||||||
|
"VolumeSizeInGB": 30,
|
||||||
|
}
|
||||||
|
request: CreateTrainingJobRequestTypeDef = {
|
||||||
|
"TrainingJobName": job.job_name,
|
||||||
|
"AlgorithmSpecification": {"TrainingImage": job.image_uri, "TrainingInputMode": "File"},
|
||||||
|
"RoleArn": job.role_arn,
|
||||||
|
"InputDataConfig": [
|
||||||
|
{
|
||||||
|
"ChannelName": "train",
|
||||||
|
"DataSource": {
|
||||||
|
"S3DataSource": {
|
||||||
|
"S3DataType": "S3Prefix",
|
||||||
|
"S3Uri": job.s3_train_uri,
|
||||||
|
"S3DataDistributionType": "FullyReplicated",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"OutputDataConfig": {"S3OutputPath": job.s3_output_path},
|
||||||
|
"ResourceConfig": resource_config,
|
||||||
|
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
|
||||||
|
"HyperParameters": hp,
|
||||||
|
}
|
||||||
|
_sm(session).create_training_job(**request)
|
||||||
|
return job.job_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> TrainingJobStatus:
|
||||||
|
resp = _sm(session).describe_training_job(TrainingJobName=job_name)
|
||||||
|
return TrainingJobStatus(
|
||||||
|
name=resp["TrainingJobName"],
|
||||||
|
status=resp["TrainingJobStatus"],
|
||||||
|
created=resp.get("CreationTime"),
|
||||||
|
modified=resp.get("LastModifiedTime"),
|
||||||
|
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
|
||||||
|
failure_reason=resp.get("FailureReason"),
|
||||||
|
raw=dict(resp),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
|
||||||
|
resp = boto3.Session(profile_name=profile, region_name=region).client("sagemaker").describe_training_job(
|
||||||
|
TrainingJobName=job_name
|
||||||
|
)
|
||||||
|
artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
|
||||||
|
if not artifact:
|
||||||
|
raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
|
||||||
|
return str(artifact)
|
||||||
|
|
||||||
|
|
||||||
|
def list_training_jobs(
|
||||||
|
session: Boto3SessionKwargs,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> list[TrainingJobSummaryTypeDef]:
|
||||||
|
resp = _sm(session).list_training_jobs(
|
||||||
|
SortBy="CreationTime",
|
||||||
|
SortOrder="Descending",
|
||||||
|
MaxResults=max_results,
|
||||||
|
)
|
||||||
|
return list(resp["TrainingJobSummaries"])
|
||||||
1
src/cloud/__init__.py
Normal file
1
src/cloud/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud provider adapters."""
|
||||||
77
src/cloud/mlflow.py
Normal file
77
src/cloud/mlflow.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from src.aws import mlflow as aws_mlflow
|
||||||
|
from src.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
class MlflowTrackingBackend(Protocol):
|
||||||
|
@property
|
||||||
|
def provider_name(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profile(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def region(self) -> str: ...
|
||||||
|
|
||||||
|
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
|
||||||
|
|
||||||
|
def auth_env(self) -> AbstractContextManager[None]: ...
|
||||||
|
|
||||||
|
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AwsMlflowTrackingBackend:
|
||||||
|
profile: str
|
||||||
|
region: str
|
||||||
|
provider_name: str = "aws"
|
||||||
|
|
||||||
|
def get_tracking_uri(self, tracking_server_name: str) -> str:
|
||||||
|
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
|
||||||
|
|
||||||
|
def auth_env(self) -> AbstractContextManager[None]:
|
||||||
|
return aws_mlflow.tracking_auth_env(self.profile, self.region)
|
||||||
|
|
||||||
|
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"provider.name": self.provider_name,
|
||||||
|
"provider.region": region,
|
||||||
|
"provider.profile": profile,
|
||||||
|
"sagemaker.role_arn": role_arn,
|
||||||
|
"sagemaker.job_name": training_job.job_name,
|
||||||
|
"sagemaker.training_image": training_job.image_uri,
|
||||||
|
"sagemaker.instance_type": training_job.instance_type,
|
||||||
|
"sagemaker.instance_count": training_job.instance_count,
|
||||||
|
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||||
|
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||||
|
"sagemaker.entry_point": training_job.entry_point,
|
||||||
|
"sagemaker.source_dir": training_job.source_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
|
||||||
|
return {"sagemaker.job_name": training_job.job_name}
|
||||||
|
|
||||||
|
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"sagemaker.training_status": training_job_status.status,
|
||||||
|
"sagemaker.created_at": training_job_status.created,
|
||||||
|
"sagemaker.modified_at": training_job_status.modified,
|
||||||
|
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||||
|
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||||
|
}
|
||||||
|
|
||||||
|
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
|
||||||
|
return {"sagemaker.job_name": training_job_status.name}
|
||||||
|
|
||||||
|
|
||||||
|
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
|
||||||
|
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)
|
||||||
567
src/commands/ai_hub.py
Normal file
567
src/commands/ai_hub.py
Normal file
@@ -0,0 +1,567 @@
|
|||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import qai_hub.hub as hub
|
||||||
|
import typer
|
||||||
|
from qai_hub.client import Device
|
||||||
|
|
||||||
|
from src import state as state_ops
|
||||||
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
|
from src.config import Config
|
||||||
|
from src.qualcomm import aihub_jobs
|
||||||
|
from src.qualcomm.artifacts import ResolvedOnnx, resolve_onnx
|
||||||
|
|
||||||
|
app = typer.Typer(help="Optimize, quantize, compile, validate, profile, and download models with Qualcomm Workbench")
|
||||||
|
|
||||||
|
_RUNTIME_EXTENSIONS = {
|
||||||
|
"tflite": "tflite",
|
||||||
|
"qnn_context_binary": "bin",
|
||||||
|
"onnx": "onnx",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UploadStep(StrEnum):
|
||||||
|
optimize = "optimize"
|
||||||
|
quantize = "quantize"
|
||||||
|
compile = "compile"
|
||||||
|
validate = "validate"
|
||||||
|
profile = "profile"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ResolvedModelSource:
|
||||||
|
model: str | Path
|
||||||
|
model_artifact: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
|
||||||
|
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
|
||||||
|
if not specs:
|
||||||
|
CONSOLE.print("[red]aihub.input_specs must define at least one input.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
return specs
|
||||||
|
|
||||||
|
|
||||||
|
def _load_inputs(
|
||||||
|
input_file: Path,
|
||||||
|
specs: Mapping[str, tuple[Sequence[int], str]],
|
||||||
|
input_name: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if not input_file.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {input_file}")
|
||||||
|
|
||||||
|
if input_file.suffix == ".npz":
|
||||||
|
loaded = np.load(input_file)
|
||||||
|
missing = set(specs) - set(loaded.files)
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing input(s) in NPZ: {', '.join(sorted(missing))}")
|
||||||
|
return {name: loaded[name] for name in specs}
|
||||||
|
|
||||||
|
if input_file.suffix == ".npy":
|
||||||
|
if input_name is None:
|
||||||
|
if len(specs) != 1:
|
||||||
|
raise ValueError("--input-name is required when config has multiple inputs")
|
||||||
|
input_name = next(iter(specs))
|
||||||
|
if input_name not in specs:
|
||||||
|
raise ValueError(f"Input name '{input_name}' is not defined in aihub.input_specs")
|
||||||
|
return {input_name: np.load(input_file)}
|
||||||
|
|
||||||
|
raise ValueError("Input file must be .npz or .npy")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_calibration(path: Path, specs: Mapping[str, tuple[Sequence[int], str]]) -> dict[str, Any]:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
return _load_inputs(path, specs)
|
||||||
|
|
||||||
|
if not path.is_dir():
|
||||||
|
raise FileNotFoundError(f"Calibration path not found: {path}")
|
||||||
|
|
||||||
|
if len(specs) != 1:
|
||||||
|
raise ValueError("Directory calibration data is supported only for single-input models.")
|
||||||
|
input_name = next(iter(specs))
|
||||||
|
samples = [np.load(p) for p in sorted(path.glob("*.npy"))]
|
||||||
|
if not samples:
|
||||||
|
raise ValueError(f"No .npy calibration samples found in {path}")
|
||||||
|
return {input_name: samples}
|
||||||
|
|
||||||
|
|
||||||
|
def _job_name(cfg: Config, operation: str) -> str | None:
|
||||||
|
if not cfg.aihub.job_name:
|
||||||
|
return None
|
||||||
|
return f"{cfg.aihub.job_name}-{operation}"
|
||||||
|
|
||||||
|
|
||||||
|
def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: bool = False) -> str:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
resolved = model_id or (st.get_last_quantized_model_id() if quantized else st.get_last_compiled_model_id())
|
||||||
|
if not resolved:
|
||||||
|
source = "quantized" if quantized else "compiled"
|
||||||
|
CONSOLE.print(f"[red]No {source} model found. Pass --model-id or run the previous AI Hub step first.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_model_source(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
*,
|
||||||
|
model_id: str | None = None,
|
||||||
|
previous_model_id: str | None = None,
|
||||||
|
from_job: str | None = None,
|
||||||
|
model_s3_uri: str | None = None,
|
||||||
|
onnx_path: str | None = None,
|
||||||
|
) -> ResolvedModelSource:
|
||||||
|
if model_id:
|
||||||
|
return ResolvedModelSource(model_id)
|
||||||
|
|
||||||
|
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
|
||||||
|
if previous_model_id and not has_explicit_source:
|
||||||
|
return ResolvedModelSource(previous_model_id)
|
||||||
|
|
||||||
|
resolved = _resolve_onnx_source(
|
||||||
|
cfg,
|
||||||
|
config_path,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
return ResolvedModelSource(resolved.onnx_path, resolved.model_artifact)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_onnx_source(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
*,
|
||||||
|
from_job: str | None = None,
|
||||||
|
model_s3_uri: str | None = None,
|
||||||
|
onnx_path: str | None = None,
|
||||||
|
) -> ResolvedOnnx:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
last_training_job = st.get_last_training_job()
|
||||||
|
saved_model_artifact = None
|
||||||
|
if not from_job and not model_s3_uri and not onnx_path and not last_training_job:
|
||||||
|
saved_model_artifact = st.get_last_model_artifact()
|
||||||
|
|
||||||
|
return resolve_onnx(
|
||||||
|
cfg=cfg,
|
||||||
|
output_dir=cfg.aihub.output_dir,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri or saved_model_artifact,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
last_training_job=last_training_job,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _device_selector(device: Device) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
if device.name:
|
||||||
|
parts.append(f"name={device.name!r}")
|
||||||
|
if device.os:
|
||||||
|
parts.append(f"os={device.os!r}")
|
||||||
|
if device.attributes:
|
||||||
|
parts.append(f"attributes={device.attributes!r}")
|
||||||
|
return ", ".join(parts) if parts else "empty selector"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_device(cfg: Config) -> None:
|
||||||
|
device = cfg.aihub.device
|
||||||
|
try:
|
||||||
|
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if matches:
|
||||||
|
return
|
||||||
|
|
||||||
|
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
|
||||||
|
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _quantize_step(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
calibration_path: Path,
|
||||||
|
*,
|
||||||
|
model_id: str | None = None,
|
||||||
|
from_job: str | None = None,
|
||||||
|
model_s3_uri: str | None = None,
|
||||||
|
onnx_path: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
specs = _input_specs(cfg)
|
||||||
|
try:
|
||||||
|
source = _resolve_model_source(
|
||||||
|
cfg,
|
||||||
|
config_path,
|
||||||
|
model_id=model_id,
|
||||||
|
previous_model_id=st.get_last_optimized_model_id(),
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
calibration_data = _load_calibration(calibration_path, specs)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hub_model = (
|
||||||
|
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
||||||
|
if isinstance(source.model, Path)
|
||||||
|
else hub.get_model(source.model)
|
||||||
|
)
|
||||||
|
result = aihub_jobs.submit_quantize_job(
|
||||||
|
hub_model,
|
||||||
|
calibration_data,
|
||||||
|
cfg.aihub.quantize_options,
|
||||||
|
job_name=_job_name(cfg, "quantize"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub quantize failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
updates: dict[str, Any] = {
|
||||||
|
"last_quantize_job_id": result["job_id"],
|
||||||
|
"last_quantized_model_id": result["model_id"],
|
||||||
|
}
|
||||||
|
if source.model_artifact:
|
||||||
|
updates["last_model_artifact"] = source.model_artifact
|
||||||
|
st.update(**updates)
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
|
||||||
|
return str(result["model_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def _optimize_step(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
from_job: str | None,
|
||||||
|
model_s3_uri: str | None,
|
||||||
|
onnx_path: str | None,
|
||||||
|
) -> str:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
_validate_device(cfg)
|
||||||
|
specs = _input_specs(cfg)
|
||||||
|
try:
|
||||||
|
source = _resolve_onnx_source(
|
||||||
|
cfg,
|
||||||
|
config_path,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hub_model = hub.upload_model(str(source.onnx_path), name=cfg.aihub.model_name)
|
||||||
|
result = aihub_jobs.submit_compile_job(
|
||||||
|
model=hub_model,
|
||||||
|
device=cfg.aihub.device,
|
||||||
|
input_specs=specs,
|
||||||
|
target_runtime="onnx",
|
||||||
|
job_name=_job_name(cfg, "optimize"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub ONNX optimization failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
st.update(
|
||||||
|
last_model_artifact=source.model_artifact,
|
||||||
|
last_optimize_job_id=result["job_id"],
|
||||||
|
last_optimized_model_id=result["model_id"],
|
||||||
|
)
|
||||||
|
CONSOLE.print(f"[green]✓[/green] ONNX optimization job: [bold]{result['job_id']}[/bold]")
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Optimized ONNX model: [bold]{result['model_id']}[/bold]")
|
||||||
|
return str(result["model_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_step(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
*,
|
||||||
|
model_id: str | None = None,
|
||||||
|
from_job: str | None = None,
|
||||||
|
model_s3_uri: str | None = None,
|
||||||
|
onnx_path: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
_validate_device(cfg)
|
||||||
|
specs = _input_specs(cfg)
|
||||||
|
try:
|
||||||
|
source = _resolve_model_source(
|
||||||
|
cfg,
|
||||||
|
config_path,
|
||||||
|
model_id=model_id,
|
||||||
|
previous_model_id=st.get_last_quantized_model_id(),
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
hub_model = (
|
||||||
|
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
||||||
|
if isinstance(source.model, Path)
|
||||||
|
else hub.get_model(source.model)
|
||||||
|
)
|
||||||
|
result = aihub_jobs.submit_compile_job(
|
||||||
|
model=hub_model,
|
||||||
|
device=cfg.aihub.device,
|
||||||
|
input_specs=specs,
|
||||||
|
target_runtime=cfg.aihub.target_runtime,
|
||||||
|
options=cfg.aihub.compile_options,
|
||||||
|
job_name=_job_name(cfg, "compile"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub compile failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
updates: dict[str, Any] = {
|
||||||
|
"last_compile_job_id": result["job_id"],
|
||||||
|
"last_compiled_model_id": result["model_id"],
|
||||||
|
}
|
||||||
|
if source.model_artifact:
|
||||||
|
updates["last_model_artifact"] = source.model_artifact
|
||||||
|
st.update(**updates)
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
|
||||||
|
return str(result["model_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_step(
|
||||||
|
cfg: Config,
|
||||||
|
config_path: str,
|
||||||
|
input_file: Path,
|
||||||
|
model_id: str | None,
|
||||||
|
input_name: str | None,
|
||||||
|
) -> str:
|
||||||
|
_validate_device(cfg)
|
||||||
|
specs = _input_specs(cfg)
|
||||||
|
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||||
|
try:
|
||||||
|
inputs = _load_inputs(input_file, specs, input_name)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
run = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
out_dir = Path(cfg.aihub.output_dir) / run / "validation"
|
||||||
|
try:
|
||||||
|
hub_model = hub.get_model(resolved_model_id)
|
||||||
|
result = aihub_jobs.submit_inference_job(
|
||||||
|
hub_model,
|
||||||
|
cfg.aihub.device,
|
||||||
|
inputs,
|
||||||
|
out_dir,
|
||||||
|
job_name=_job_name(cfg, "validate"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub inference failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
state_ops.store(config_path).update(last_inference_job_id=result["job_id"])
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Inference job: [bold]{result['job_id']}[/bold]")
|
||||||
|
outputs = result.get("outputs")
|
||||||
|
if isinstance(outputs, dict):
|
||||||
|
for name, value in outputs.items():
|
||||||
|
CONSOLE.print(f" {name}: shape={getattr(value, 'shape', '?')}")
|
||||||
|
CONSOLE.print(f"Outputs: [cyan]{out_dir}[/cyan]")
|
||||||
|
return str(result["job_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
||||||
|
_validate_device(cfg)
|
||||||
|
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||||
|
try:
|
||||||
|
hub_model = hub.get_model(resolved_model_id)
|
||||||
|
result = aihub_jobs.submit_profile_job(
|
||||||
|
hub_model,
|
||||||
|
cfg.aihub.device,
|
||||||
|
cfg.aihub.profile_options,
|
||||||
|
job_name=_job_name(cfg, "profile"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub profile failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
state_ops.store(config_path).update(last_profile_job_id=result["job_id"])
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Profile job: [bold]{result['job_id']}[/bold]")
|
||||||
|
return str(result["job_id"])
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def optimize(
|
||||||
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should optimize"),
|
||||||
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to optimize"),
|
||||||
|
onnx_path: str | None = typer.Option(
|
||||||
|
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||||
|
),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Optimize an external model into a Workbench-produced ONNX model."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
_optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def quantize(
|
||||||
|
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||||
|
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub optimized ONNX model ID"),
|
||||||
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should quantize"),
|
||||||
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to quantize"),
|
||||||
|
onnx_path: str | None = typer.Option(
|
||||||
|
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||||
|
),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Quantize an ONNX model to INT8."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
_quantize_step(
|
||||||
|
cfg,
|
||||||
|
config,
|
||||||
|
calibration_path,
|
||||||
|
model_id=model_id,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def compile(
|
||||||
|
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub model ID to compile"),
|
||||||
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should compile"),
|
||||||
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to compile"),
|
||||||
|
onnx_path: str | None = typer.Option(
|
||||||
|
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||||
|
),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Compile a model for the configured Qualcomm AI Hub target."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
_compile_step(
|
||||||
|
cfg,
|
||||||
|
config,
|
||||||
|
model_id=model_id,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def validate(
|
||||||
|
input_file: Path = typer.Argument(..., help="NumPy .npz or .npy inputs to run on device"),
|
||||||
|
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||||
|
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy files"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Run an AI Hub inference job using sample inputs."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
_validate_step(cfg, config, input_file, model_id, input_name)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def profile(
|
||||||
|
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Profile a compiled model on the configured AI Hub device."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
_profile_step(cfg, config, model_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def upload(
|
||||||
|
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||||
|
input_file: Path = typer.Argument(..., help="Validation .npz or .npy inputs to run on device"),
|
||||||
|
from_step: UploadStep = typer.Option(UploadStep.optimize, "--from-step", help="Resume from this Workbench step"),
|
||||||
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should upload"),
|
||||||
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to upload"),
|
||||||
|
onnx_path: str | None = typer.Option(
|
||||||
|
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||||
|
),
|
||||||
|
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy validation files"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Optimize, quantize, optionally compile, validate, and profile a model."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
steps = [UploadStep.optimize, UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
|
||||||
|
selected = steps[steps.index(from_step) :]
|
||||||
|
|
||||||
|
optimized_model_id: str | None = None
|
||||||
|
quantized_model_id: str | None = None
|
||||||
|
compiled_model_id: str | None = None
|
||||||
|
if UploadStep.optimize in selected:
|
||||||
|
optimized_model_id = _optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
||||||
|
if UploadStep.quantize in selected:
|
||||||
|
if UploadStep.optimize not in selected:
|
||||||
|
optimized_model_id = state_ops.store(config).get_last_optimized_model_id()
|
||||||
|
if not optimized_model_id:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[red]No optimized ONNX model found. Resume from --from-step optimize or run "
|
||||||
|
"'qc-cli ai-hub optimize' first.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
quantized_model_id = _quantize_step(
|
||||||
|
cfg,
|
||||||
|
config,
|
||||||
|
calibration_path,
|
||||||
|
model_id=optimized_model_id,
|
||||||
|
)
|
||||||
|
if UploadStep.compile in selected:
|
||||||
|
if cfg.aihub.target_runtime == "onnx":
|
||||||
|
compiled_model_id = quantized_model_id or state_ops.store(config).get_last_quantized_model_id()
|
||||||
|
if not compiled_model_id:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[red]No quantized ONNX model found. Resume from --from-step quantize or run "
|
||||||
|
"'qc-cli ai-hub quantize' first.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
state_ops.store(config).update(last_compiled_model_id=compiled_model_id)
|
||||||
|
CONSOLE.print("[green]✓[/green] Target runtime is ONNX; skipping final compile.")
|
||||||
|
else:
|
||||||
|
compiled_model_id = _compile_step(
|
||||||
|
cfg,
|
||||||
|
config,
|
||||||
|
model_id=quantized_model_id,
|
||||||
|
)
|
||||||
|
if UploadStep.validate in selected:
|
||||||
|
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
|
||||||
|
if UploadStep.profile in selected:
|
||||||
|
_profile_step(cfg, config, compiled_model_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def download(
|
||||||
|
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||||
|
output: Path | None = typer.Option(None, "--output", "-o", help="Destination file path"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Download the last compiled deployable artifact from AI Hub."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
resolved_model_id = _model_id_or_state(config, model_id)
|
||||||
|
ext = _RUNTIME_EXTENSIONS.get(cfg.aihub.target_runtime, cfg.aihub.target_runtime)
|
||||||
|
dest = output or (Path(cfg.aihub.output_dir) / f"model.{ext}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
written = aihub_jobs.download_model(resolved_model_id, dest)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]AI Hub download failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
state_ops.store(config).update(last_downloaded_model=written)
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Downloaded model: [cyan]{written}[/cyan]")
|
||||||
@@ -51,6 +51,8 @@ def setup(
|
|||||||
profile=cfg.aws.profile,
|
profile=cfg.aws.profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
region=cfg.aws.region,
|
region=cfg.aws.region,
|
||||||
|
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||||
cloudformation_execution_policy=cloudformation_execution_policy,
|
cloudformation_execution_policy=cloudformation_execution_policy,
|
||||||
)
|
)
|
||||||
with CONSOLE.status("Running cdk deploy..."):
|
with CONSOLE.status("Running cdk deploy..."):
|
||||||
@@ -58,6 +60,9 @@ def setup(
|
|||||||
profile=cfg.aws.profile,
|
profile=cfg.aws.profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
region=cfg.aws.region,
|
region=cfg.aws.region,
|
||||||
|
stack_name=cfg.infra.stack_name,
|
||||||
|
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||||
config_path=config,
|
config_path=config,
|
||||||
config_dir=str(Path(config).parent),
|
config_dir=str(Path(config).parent),
|
||||||
config_snapshot=cfg.model_dump(mode="json"),
|
config_snapshot=cfg.model_dump(mode="json"),
|
||||||
@@ -72,7 +77,8 @@ def setup(
|
|||||||
if outputs.get("SageMakerRoleArn"):
|
if outputs.get("SageMakerRoleArn"):
|
||||||
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
||||||
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
||||||
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
|
mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
|
||||||
|
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
|
||||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||||
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
||||||
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
||||||
@@ -82,7 +88,7 @@ def setup(
|
|||||||
def status(config: str = CONFIG_OPT) -> None:
|
def status(config: str = CONFIG_OPT) -> None:
|
||||||
"""Show current infrastructure status."""
|
"""Show current infrastructure status."""
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
|
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)
|
||||||
|
|
||||||
table = Table(title="Infrastructure Status")
|
table = Table(title="Infrastructure Status")
|
||||||
table.add_column("Resource", style="cyan")
|
table.add_column("Resource", style="cyan")
|
||||||
@@ -91,13 +97,13 @@ def status(config: str = CONFIG_OPT) -> None:
|
|||||||
table.add_column("ARN / URI")
|
table.add_column("ARN / URI")
|
||||||
|
|
||||||
if not stack:
|
if not stack:
|
||||||
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
|
table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
|
||||||
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
||||||
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
||||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"MLflow",
|
"MLflow",
|
||||||
cfg.mlflow.tracking_server_name or "-",
|
cfg.effective_mlflow_tracking_server_name or "-",
|
||||||
"[red]unknown[/red]",
|
"[red]unknown[/red]",
|
||||||
"-",
|
"-",
|
||||||
)
|
)
|
||||||
@@ -114,14 +120,14 @@ def status(config: str = CONFIG_OPT) -> None:
|
|||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"IAM Role",
|
"IAM Role",
|
||||||
cfg.sagemaker.role_name,
|
_role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
|
||||||
"[green]managed[/green]",
|
"[green]managed[/green]",
|
||||||
outputs.get("SageMakerRoleArn", "-"),
|
outputs.get("SageMakerRoleArn", "-"),
|
||||||
)
|
)
|
||||||
if cfg.mlflow.mode is MlflowMode.create:
|
if cfg.mlflow.mode is MlflowMode.create:
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"MLflow",
|
"MLflow",
|
||||||
cfg.mlflow.tracking_server_name or "-",
|
outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
|
||||||
"[green]managed[/green]",
|
"[green]managed[/green]",
|
||||||
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
||||||
)
|
)
|
||||||
@@ -156,10 +162,13 @@ def destroy(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Destroy the CDK stack."""
|
"""Destroy the CDK stack."""
|
||||||
cfg = _destroy_config(config)
|
cfg = _destroy_config(config)
|
||||||
|
stack_name = _destroy_stack_name(config, cfg)
|
||||||
|
bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
|
||||||
|
toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)
|
||||||
|
|
||||||
if not yes and not delete_bucket_data:
|
if not yes and not delete_bucket_data:
|
||||||
typer.confirm(
|
typer.confirm(
|
||||||
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
|
f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?",
|
||||||
abort=True,
|
abort=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,13 +181,17 @@ def destroy(
|
|||||||
provisioning.destroy(
|
provisioning.destroy(
|
||||||
profile=cfg.aws.profile,
|
profile=cfg.aws.profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
|
stack_name=stack_name,
|
||||||
|
bootstrap_qualifier=bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=toolkit_stack_name,
|
||||||
config_path=str(snapshot_path),
|
config_path=str(snapshot_path),
|
||||||
delete_bucket_data=delete_bucket_data,
|
delete_bucket_data=delete_bucket_data,
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
CONSOLE.print(f"[red]{e}[/red]")
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
|
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
|
||||||
|
CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
def _destroy_config(config_path: str) -> Config:
|
def _destroy_config(config_path: str) -> Config:
|
||||||
@@ -190,6 +203,14 @@ def _destroy_config(config_path: str) -> Config:
|
|||||||
return load_cfg(config_path)
|
return load_cfg(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _role_name(configured_name: str, role_arn: str) -> str:
|
||||||
|
if configured_name:
|
||||||
|
return configured_name
|
||||||
|
if role_arn:
|
||||||
|
return role_arn.rsplit("/", 1)[-1]
|
||||||
|
return "-"
|
||||||
|
|
||||||
|
|
||||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||||
config_dir = str(Path(config_path).parent)
|
config_dir = str(Path(config_path).parent)
|
||||||
state = read_infra_state(config_dir)
|
state = read_infra_state(config_dir)
|
||||||
@@ -197,3 +218,30 @@ def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
|||||||
if account_id:
|
if account_id:
|
||||||
return str(account_id)
|
return str(account_id)
|
||||||
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||||
|
|
||||||
|
|
||||||
|
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
|
||||||
|
config_dir = str(Path(config_path).parent)
|
||||||
|
state = read_infra_state(config_dir)
|
||||||
|
stack_name = state.get("stack_name")
|
||||||
|
if stack_name:
|
||||||
|
return str(stack_name)
|
||||||
|
return cfg.infra.stack_name
|
||||||
|
|
||||||
|
|
||||||
|
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
|
||||||
|
config_dir = str(Path(config_path).parent)
|
||||||
|
state = read_infra_state(config_dir)
|
||||||
|
bootstrap_qualifier = state.get("bootstrap_qualifier")
|
||||||
|
if bootstrap_qualifier:
|
||||||
|
return str(bootstrap_qualifier)
|
||||||
|
return cfg.infra.effective_bootstrap_qualifier
|
||||||
|
|
||||||
|
|
||||||
|
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
|
||||||
|
config_dir = str(Path(config_path).parent)
|
||||||
|
state = read_infra_state(config_dir)
|
||||||
|
toolkit_stack_name = state.get("toolkit_stack_name")
|
||||||
|
if toolkit_stack_name:
|
||||||
|
return str(toolkit_stack_name)
|
||||||
|
return cfg.infra.effective_toolkit_stack_name
|
||||||
|
|||||||
40
src/commands/init.py
Normal file
40
src/commands/init.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import secrets
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from src.commands.utils import CONSOLE
|
||||||
|
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def init(
|
||||||
|
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
||||||
|
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
||||||
|
) -> None:
|
||||||
|
"""Write a starter config.yaml to the current directory."""
|
||||||
|
dest = Path(output)
|
||||||
|
if dest.exists() and not force:
|
||||||
|
CONSOLE.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
config = _new_isolated_config()
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_data = config.model_dump(mode="json")
|
||||||
|
config_data["sagemaker"].pop("role_name", None)
|
||||||
|
with open(dest, "w") as f:
|
||||||
|
yaml.safe_dump(config_data, f, sort_keys=False)
|
||||||
|
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||||
|
CONSOLE.print("Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands.")
|
||||||
|
|
||||||
|
|
||||||
|
def _new_isolated_config() -> Config:
|
||||||
|
suffix = secrets.token_hex(6)
|
||||||
|
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
|
||||||
|
config = Config(infra=InfraConfig(stack_name=namespace))
|
||||||
|
config.s3 = S3Config(bucket=f"{namespace}-data")
|
||||||
|
return config
|
||||||
95
src/commands/mlflow.py
Normal file
95
src/commands/mlflow.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import webbrowser
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from src import state as state_ops
|
||||||
|
from src.aws import mlflow as aws_mlflow
|
||||||
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
|
from src.config import MlflowMode
|
||||||
|
from src.tracking.upload import upload_training_metrics
|
||||||
|
|
||||||
|
app = typer.Typer(help="Manage MLflow tracking server access")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(name="open")
|
||||||
|
def open_mlflow(config: str = CONFIG_OPT) -> None:
|
||||||
|
"""Open a presigned URL for the configured MLflow tracking server."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||||
|
if not tracking_server_name:
|
||||||
|
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = aws_mlflow.create_presigned_tracking_server_url(
|
||||||
|
cfg.aws.region,
|
||||||
|
cfg.aws.profile,
|
||||||
|
tracking_server_name,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
|
||||||
|
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||||
|
CONSOLE.print(f"Reason: {e}")
|
||||||
|
CONSOLE.print(
|
||||||
|
"This command can create presigned URLs only for MLflow tracking servers managed by "
|
||||||
|
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||||
|
CONSOLE.print(f"MLflow UI: {url}")
|
||||||
|
if webbrowser.open(url):
|
||||||
|
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
|
||||||
|
else:
|
||||||
|
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(name="upload-metrics")
|
||||||
|
def upload_metrics(
|
||||||
|
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||||
|
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Upload a completed training job's metric history to MLflow."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||||
|
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
st = state_ops.store(config)
|
||||||
|
if not job_name:
|
||||||
|
job_name = st.get_last_training_job()
|
||||||
|
if not job_name:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = upload_training_metrics(
|
||||||
|
job_name=job_name,
|
||||||
|
config_path=config,
|
||||||
|
cfg=cfg,
|
||||||
|
force=force,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if result.metrics_history_uploaded:
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
|
||||||
|
else:
|
||||||
|
CONSOLE.print(
|
||||||
|
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
|
||||||
|
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
|
||||||
|
)
|
||||||
|
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
|
||||||
|
if result.registered_model_version:
|
||||||
|
CONSOLE.print(
|
||||||
|
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||||
|
"([cyan]experiment-latest[/cyan])"
|
||||||
|
)
|
||||||
265
src/commands/train.py
Normal file
265
src/commands/train.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from src import state as state_ops
|
||||||
|
from src.aws import iam
|
||||||
|
from src.aws import sagemaker as sm_ops
|
||||||
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
|
from src.config import Config, MlflowMode
|
||||||
|
from src.infra.state import read_infra_state
|
||||||
|
from src.tracking.mlflow import MlflowTracker
|
||||||
|
from src.tracking.upload import upload_training_metrics
|
||||||
|
|
||||||
|
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||||
|
|
||||||
|
_STATUS_COLOR = {
|
||||||
|
"Completed": "green",
|
||||||
|
"Failed": "red",
|
||||||
|
"InProgress": "yellow",
|
||||||
|
"Stopping": "yellow",
|
||||||
|
"Stopped": "dim",
|
||||||
|
}
|
||||||
|
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
|
||||||
|
DEFAULT_POLL_INTERVAL_SECONDS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _tracker(cfg):
|
||||||
|
try:
|
||||||
|
return MlflowTracker.from_config(cfg)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]MLflow setup failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _config_dir(config_path: str) -> str:
|
||||||
|
return str(Path(config_path).parent)
|
||||||
|
|
||||||
|
|
||||||
|
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
||||||
|
state = read_infra_state(_config_dir(config_path))
|
||||||
|
role_arn = state.get("outputs", {}).get("SageMakerRoleArn")
|
||||||
|
if role_arn:
|
||||||
|
return str(role_arn)
|
||||||
|
if cfg.sagemaker.role_name:
|
||||||
|
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||||
|
if role_arn:
|
||||||
|
return role_arn
|
||||||
|
raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")
|
||||||
|
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
|
||||||
|
color = _STATUS_COLOR.get(status.status, "white")
|
||||||
|
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||||
|
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||||
|
if status.created:
|
||||||
|
CONSOLE.print(f"Created: {status.created}")
|
||||||
|
if status.model_artifacts:
|
||||||
|
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||||
|
if status.failure_reason:
|
||||||
|
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_and_upload_metrics(
|
||||||
|
*,
|
||||||
|
job_name: str,
|
||||||
|
poll_interval: int,
|
||||||
|
config_path: str,
|
||||||
|
cfg: Config,
|
||||||
|
) -> None:
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
previous_status: str | None = None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||||
|
if training_status.status != previous_status:
|
||||||
|
color = _STATUS_COLOR.get(training_status.status, "white")
|
||||||
|
CONSOLE.print(
|
||||||
|
f"Job [cyan]{training_status.name}[/cyan]: "
|
||||||
|
f"[{color}]{training_status.status}[/{color}]"
|
||||||
|
)
|
||||||
|
previous_status = training_status.status
|
||||||
|
if training_status.status in _TERMINAL_STATUSES:
|
||||||
|
_print_training_status(training_status)
|
||||||
|
if training_status.status != "Completed":
|
||||||
|
raise typer.Exit(1)
|
||||||
|
try:
|
||||||
|
result = upload_training_metrics(
|
||||||
|
job_name=job_name,
|
||||||
|
config_path=config_path,
|
||||||
|
cfg=cfg,
|
||||||
|
)
|
||||||
|
if result.metrics_history_uploaded:
|
||||||
|
CONSOLE.print(
|
||||||
|
f"[green]✓[/green] Uploaded training metrics to MLflow run "
|
||||||
|
f"[cyan]{result.run_id}[/cyan]."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
|
||||||
|
"Uploaded SageMaker final metrics only.[/yellow]"
|
||||||
|
)
|
||||||
|
if result.registered_model_version:
|
||||||
|
CONSOLE.print(
|
||||||
|
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||||
|
"([cyan]experiment-latest[/cyan])"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||||
|
CONSOLE.print(
|
||||||
|
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
job_state = st.get_training_job(job_name)
|
||||||
|
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||||
|
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||||
|
return
|
||||||
|
time.sleep(poll_interval)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
||||||
|
raise typer.Exit(130)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def start(
|
||||||
|
upload_metrics: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--upload-metrics",
|
||||||
|
help="Wait for completion, then upload training metrics to MLflow",
|
||||||
|
),
|
||||||
|
poll_interval: int = typer.Option(
|
||||||
|
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||||
|
"--poll-interval",
|
||||||
|
min=1,
|
||||||
|
help="Seconds between status checks when --upload-metrics is used",
|
||||||
|
),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Submit a SageMaker training job."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
|
||||||
|
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if not cfg.sagemaker.training.image_uri:
|
||||||
|
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||||
|
CONSOLE.print(
|
||||||
|
"Find pre-built images at: "
|
||||||
|
"https://aws.github.io/deep-learning-containers/reference/available_images"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
role_arn = _sagemaker_role_arn(config, cfg)
|
||||||
|
except RuntimeError as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
tracker = _tracker(cfg)
|
||||||
|
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||||
|
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
|
||||||
|
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
|
||||||
|
|
||||||
|
CONSOLE.print(f"Submitting training job [cyan]{job_name}[/cyan]...")
|
||||||
|
training_job = sm_ops.TrainingJobRequest(
|
||||||
|
role_arn=role_arn,
|
||||||
|
image_uri=cfg.sagemaker.training.image_uri,
|
||||||
|
instance_type=cfg.sagemaker.training.instance_type,
|
||||||
|
instance_count=cfg.sagemaker.training.instance_count,
|
||||||
|
s3_train_uri=s3_train_uri,
|
||||||
|
s3_output_path=s3_output,
|
||||||
|
job_name=job_name,
|
||||||
|
hyperparameters=cfg.sagemaker.training.hyperparameters,
|
||||||
|
entry_point=cfg.sagemaker.training.entry_point,
|
||||||
|
source_dir=cfg.sagemaker.training.source_dir,
|
||||||
|
)
|
||||||
|
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
|
||||||
|
|
||||||
|
st = state_ops.store(config)
|
||||||
|
st.set_last_training_job(job_name)
|
||||||
|
try:
|
||||||
|
run_id = tracker.start_training_run(
|
||||||
|
training_job,
|
||||||
|
region=cfg.aws.region,
|
||||||
|
profile=cfg.aws.profile,
|
||||||
|
role_arn=role_arn,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_id = None
|
||||||
|
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
|
||||||
|
CONSOLE.print(
|
||||||
|
"The SageMaker job is still running. Upload metrics after completion with "
|
||||||
|
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||||
|
)
|
||||||
|
if run_id:
|
||||||
|
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||||
|
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||||
|
if run_id:
|
||||||
|
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||||
|
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||||
|
if upload_metrics:
|
||||||
|
_wait_and_upload_metrics(
|
||||||
|
job_name=job_name,
|
||||||
|
poll_interval=poll_interval,
|
||||||
|
config_path=config,
|
||||||
|
cfg=cfg,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def status(
|
||||||
|
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Show training job status."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
st = state_ops.store(config)
|
||||||
|
|
||||||
|
if not job_name:
|
||||||
|
job_name = st.get_last_training_job()
|
||||||
|
if not job_name:
|
||||||
|
CONSOLE.print(
|
||||||
|
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||||
|
_print_training_status(status)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(name="list")
|
||||||
|
def list_jobs(
|
||||||
|
limit: int = typer.Option(10, "--limit", "-n", help="Number of jobs to show"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""List recent training jobs."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
jobs = sm_ops.list_training_jobs(cfg.aws.boto3_session, max_results=limit)
|
||||||
|
|
||||||
|
if not jobs:
|
||||||
|
CONSOLE.print("[yellow]No training jobs found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Training Jobs")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Status")
|
||||||
|
table.add_column("Created")
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
status_value = str(job["TrainingJobStatus"])
|
||||||
|
color = _STATUS_COLOR.get(status_value, "white")
|
||||||
|
table.add_row(
|
||||||
|
str(job["TrainingJobName"]),
|
||||||
|
f"[{color}]{status_value}[/{color}]",
|
||||||
|
str(job.get("CreationTime", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
CONSOLE.print(table)
|
||||||
70
src/commands/upload.py
Normal file
70
src/commands/upload.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||||
|
|
||||||
|
from src.aws import s3 as s3_ops
|
||||||
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def upload(
|
||||||
|
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
||||||
|
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Upload a local file or directory to S3."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
|
||||||
|
try:
|
||||||
|
with CONSOLE.status(f"Uploading {path.name}..."):
|
||||||
|
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
CONSOLE.print(f"[green]✓[/green] {path.name} -> {uri}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if path.is_dir():
|
||||||
|
if s3_key is not None:
|
||||||
|
CONSOLE.print("[red]--s3-key can only be used when uploading a single file.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
files = [file for file in path.rglob("*") if file.is_file()]
|
||||||
|
if not files:
|
||||||
|
CONSOLE.print("[yellow]No files found in directory.[/yellow]")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
prefix = cfg.s3.data_prefix
|
||||||
|
CONSOLE.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||||
|
try:
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
console=CONSOLE,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("Uploading...", total=len(files))
|
||||||
|
count = s3_ops.upload_dir(
|
||||||
|
cfg.aws.region,
|
||||||
|
cfg.aws.profile,
|
||||||
|
cfg.s3.bucket,
|
||||||
|
str(path),
|
||||||
|
prefix,
|
||||||
|
on_progress=lambda: progress.advance(task),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
CONSOLE.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||||
|
return
|
||||||
|
|
||||||
|
CONSOLE.print(f"[red]Path not found: {path}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
@@ -14,7 +14,7 @@ def load_config(path: str = "config.yaml") -> Config:
|
|||||||
config_path = Path(path)
|
config_path = Path(path)
|
||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Config file not found: {config_path}. Run 'qai-cli init' to create one."
|
f"Config file not found: {config_path}. Run 'qc-cli init' to create one."
|
||||||
)
|
)
|
||||||
with open(config_path) as f:
|
with open(config_path) as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
|
|||||||
@@ -1,30 +1,68 @@
|
|||||||
from enum import Enum
|
import re
|
||||||
from typing import Any, Literal
|
from enum import StrEnum
|
||||||
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
from qai_hub.client import Device
|
||||||
|
|
||||||
|
|
||||||
class MlflowMode(str, Enum):
|
class MlflowMode(StrEnum):
|
||||||
disabled = "disabled"
|
disabled = "disabled"
|
||||||
create = "create"
|
create = "create"
|
||||||
existing = "existing"
|
existing = "existing"
|
||||||
|
|
||||||
|
|
||||||
class MlflowServerSize(str, Enum):
|
class MlflowServerSize(StrEnum):
|
||||||
small = "Small"
|
small = "Small"
|
||||||
medium = "Medium"
|
medium = "Medium"
|
||||||
large = "Large"
|
large = "Large"
|
||||||
|
|
||||||
|
|
||||||
|
class Boto3SessionKwargs(TypedDict):
|
||||||
|
profile_name: str
|
||||||
|
region_name: str
|
||||||
|
|
||||||
|
|
||||||
class AwsConfig(BaseModel):
|
class AwsConfig(BaseModel):
|
||||||
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
region: BucketLocationConstraintType | Literal["us-east-1"] = "us-east-1"
|
||||||
profile: str = "default"
|
profile: str = "default"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def boto3_session(self) -> Boto3SessionKwargs:
|
||||||
|
return {"profile_name": self.profile, "region_name": self.region}
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
|
||||||
|
GENERATED_STACK_PREFIX = "qc-cli-mlops-"
|
||||||
|
|
||||||
|
|
||||||
|
class InfraConfig(BaseModel):
|
||||||
|
stack_name: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def effective_bootstrap_qualifier(self) -> str:
|
||||||
|
sanitized = re.sub(r"[^a-z0-9]", "", self.stack_name.lower())
|
||||||
|
if not sanitized:
|
||||||
|
return DEFAULT_BOOTSTRAP_QUALIFIER
|
||||||
|
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
|
||||||
|
suffix = re.sub(r"[^a-z0-9]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX).lower())
|
||||||
|
if suffix:
|
||||||
|
return f"q{suffix}"[:10]
|
||||||
|
return f"q{sanitized}"[:10]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def effective_toolkit_stack_name(self) -> str:
|
||||||
|
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
|
||||||
|
suffix = re.sub(r"[^A-Za-z0-9-]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
|
||||||
|
if suffix:
|
||||||
|
return f"{self.stack_name}-bootstrap"
|
||||||
|
return f"{self.stack_name}-bootstrap"
|
||||||
|
|
||||||
|
|
||||||
class S3Config(BaseModel):
|
class S3Config(BaseModel):
|
||||||
bucket: str = "my-onnx-bucket"
|
bucket: str = "my-qc-mlops-bucket"
|
||||||
data_prefix: str = "data/"
|
data_prefix: str = "data/"
|
||||||
model_prefix: str = "models/"
|
model_prefix: str = "models/"
|
||||||
|
|
||||||
@@ -39,13 +77,35 @@ class TrainingConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SageMakerConfig(BaseModel):
|
class SageMakerConfig(BaseModel):
|
||||||
role_name: str = "qai-cli-sagemaker-role"
|
role_name: str = ""
|
||||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class AIHubConfig(BaseModel):
|
||||||
|
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
|
||||||
|
target_runtime: str = "tflite"
|
||||||
|
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
|
||||||
|
job_name: str | None = None
|
||||||
|
model_name: str | None = None
|
||||||
|
compile_options: str | None = None
|
||||||
|
profile_options: str | None = None
|
||||||
|
quantize_options: str | None = None
|
||||||
|
output_dir: str = "build/qai-hub"
|
||||||
|
|
||||||
|
@field_validator("device", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_device(cls, value: Any) -> Any:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return Device(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class MlflowConfig(BaseModel):
|
class MlflowConfig(BaseModel):
|
||||||
mode: MlflowMode = MlflowMode.disabled
|
mode: MlflowMode = MlflowMode.disabled
|
||||||
tracking_server_name: str | None = None
|
tracking_server_name: str | None = None
|
||||||
|
experiment_name: str = "qc-cli-training"
|
||||||
|
registered_model_name: str = "qc-cli-model"
|
||||||
|
register_trained_models: bool = True
|
||||||
artifact_prefix: str = "mlflow/"
|
artifact_prefix: str = "mlflow/"
|
||||||
tracking_server_size: MlflowServerSize = MlflowServerSize.small
|
tracking_server_size: MlflowServerSize = MlflowServerSize.small
|
||||||
mlflow_version: str | None = None
|
mlflow_version: str | None = None
|
||||||
@@ -54,13 +114,27 @@ class MlflowConfig(BaseModel):
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def require_tracking_server_name(self) -> "MlflowConfig":
|
def require_tracking_server_name(self) -> "MlflowConfig":
|
||||||
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
|
if self.mode is MlflowMode.existing and not self.tracking_server_name:
|
||||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
|
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
infra: InfraConfig
|
||||||
aws: AwsConfig = Field(default_factory=AwsConfig)
|
aws: AwsConfig = Field(default_factory=AwsConfig)
|
||||||
s3: S3Config = Field(default_factory=S3Config)
|
s3: S3Config = Field(default_factory=S3Config)
|
||||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||||
|
aihub: AIHubConfig = Field(default_factory=AIHubConfig)
|
||||||
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def managed_mlflow_tracking_server_name(self) -> str:
|
||||||
|
return f"{self.infra.stack_name}-mlflow"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def effective_mlflow_tracking_server_name(self) -> str | None:
|
||||||
|
if self.mlflow.mode is MlflowMode.disabled:
|
||||||
|
return None
|
||||||
|
if self.mlflow.mode is MlflowMode.existing:
|
||||||
|
return self.mlflow.tracking_server_name
|
||||||
|
return self.managed_mlflow_tracking_server_name
|
||||||
|
|||||||
@@ -5,17 +5,27 @@ from typing import Any
|
|||||||
|
|
||||||
from src.infra.state import state_path, write_infra_state
|
from src.infra.state import state_path, write_infra_state
|
||||||
|
|
||||||
STACK_NAME = "QaiCliStack"
|
|
||||||
|
|
||||||
|
|
||||||
def bootstrap(
|
def bootstrap(
|
||||||
*,
|
*,
|
||||||
profile: str,
|
profile: str,
|
||||||
account_id: str,
|
account_id: str,
|
||||||
region: str,
|
region: str,
|
||||||
|
bootstrap_qualifier: str,
|
||||||
|
toolkit_stack_name: str,
|
||||||
cloudformation_execution_policy: str | None = None,
|
cloudformation_execution_policy: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
cmd = ["cdk", "bootstrap", f"aws://{account_id}/{region}", "--profile", profile]
|
cmd = [
|
||||||
|
"cdk",
|
||||||
|
"bootstrap",
|
||||||
|
f"aws://{account_id}/{region}",
|
||||||
|
"--profile",
|
||||||
|
profile,
|
||||||
|
"--qualifier",
|
||||||
|
bootstrap_qualifier,
|
||||||
|
"--toolkit-stack-name",
|
||||||
|
toolkit_stack_name,
|
||||||
|
]
|
||||||
if cloudformation_execution_policy:
|
if cloudformation_execution_policy:
|
||||||
cmd.extend(["--cloudformation-execution-policies", cloudformation_execution_policy])
|
cmd.extend(["--cloudformation-execution-policies", cloudformation_execution_policy])
|
||||||
_run(cmd)
|
_run(cmd)
|
||||||
@@ -26,6 +36,9 @@ def deploy(
|
|||||||
profile: str,
|
profile: str,
|
||||||
account_id: str,
|
account_id: str,
|
||||||
region: str,
|
region: str,
|
||||||
|
stack_name: str,
|
||||||
|
bootstrap_qualifier: str,
|
||||||
|
toolkit_stack_name: str,
|
||||||
config_path: str,
|
config_path: str,
|
||||||
config_dir: str,
|
config_dir: str,
|
||||||
config_snapshot: dict[str, Any],
|
config_snapshot: dict[str, Any],
|
||||||
@@ -35,19 +48,24 @@ def deploy(
|
|||||||
"deploy",
|
"deploy",
|
||||||
profile=profile,
|
profile=profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
|
stack_name=stack_name,
|
||||||
|
bootstrap_qualifier=bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=toolkit_stack_name,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
delete_bucket_data=False,
|
delete_bucket_data=False,
|
||||||
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
|
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
|
||||||
_run(cmd)
|
_run(cmd)
|
||||||
|
|
||||||
outputs = _read_outputs(outputs_file)
|
outputs = _read_outputs(outputs_file, stack_name)
|
||||||
state = {
|
state = {
|
||||||
"stack_name": STACK_NAME,
|
"stack_name": stack_name,
|
||||||
"aws": {
|
"aws": {
|
||||||
"account_id": account_id,
|
"account_id": account_id,
|
||||||
"region": region,
|
"region": region,
|
||||||
"profile": profile,
|
"profile": profile,
|
||||||
},
|
},
|
||||||
|
"bootstrap_qualifier": bootstrap_qualifier,
|
||||||
|
"toolkit_stack_name": toolkit_stack_name,
|
||||||
"config": config_snapshot,
|
"config": config_snapshot,
|
||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
}
|
}
|
||||||
@@ -59,6 +77,9 @@ def destroy(
|
|||||||
*,
|
*,
|
||||||
profile: str,
|
profile: str,
|
||||||
account_id: str,
|
account_id: str,
|
||||||
|
stack_name: str,
|
||||||
|
bootstrap_qualifier: str,
|
||||||
|
toolkit_stack_name: str,
|
||||||
config_path: str,
|
config_path: str,
|
||||||
delete_bucket_data: bool,
|
delete_bucket_data: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -67,6 +88,9 @@ def destroy(
|
|||||||
"deploy",
|
"deploy",
|
||||||
profile=profile,
|
profile=profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
|
stack_name=stack_name,
|
||||||
|
bootstrap_qualifier=bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=toolkit_stack_name,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
delete_bucket_data=True,
|
delete_bucket_data=True,
|
||||||
) + ["--require-approval", "never"]
|
) + ["--require-approval", "never"]
|
||||||
@@ -76,6 +100,9 @@ def destroy(
|
|||||||
"destroy",
|
"destroy",
|
||||||
profile=profile,
|
profile=profile,
|
||||||
account_id=account_id,
|
account_id=account_id,
|
||||||
|
stack_name=stack_name,
|
||||||
|
bootstrap_qualifier=bootstrap_qualifier,
|
||||||
|
toolkit_stack_name=toolkit_stack_name,
|
||||||
config_path=config_path,
|
config_path=config_path,
|
||||||
delete_bucket_data=delete_bucket_data,
|
delete_bucket_data=delete_bucket_data,
|
||||||
) + ["--force"]
|
) + ["--force"]
|
||||||
@@ -87,26 +114,35 @@ def _cdk_cmd(
|
|||||||
*,
|
*,
|
||||||
profile: str,
|
profile: str,
|
||||||
account_id: str,
|
account_id: str,
|
||||||
|
stack_name: str,
|
||||||
|
bootstrap_qualifier: str,
|
||||||
|
toolkit_stack_name: str,
|
||||||
config_path: str,
|
config_path: str,
|
||||||
delete_bucket_data: bool,
|
delete_bucket_data: bool,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
cmd = [
|
cmd = [
|
||||||
"cdk",
|
"cdk",
|
||||||
action,
|
action,
|
||||||
STACK_NAME,
|
stack_name,
|
||||||
"--app",
|
"--app",
|
||||||
"python app.py",
|
"python app.py",
|
||||||
"--profile",
|
"--profile",
|
||||||
profile,
|
profile,
|
||||||
|
]
|
||||||
|
if action == "deploy":
|
||||||
|
cmd.extend(["--toolkit-stack-name", toolkit_stack_name])
|
||||||
|
cmd.extend([
|
||||||
"-c",
|
"-c",
|
||||||
f"account_id={account_id}",
|
f"account_id={account_id}",
|
||||||
"-c",
|
"-c",
|
||||||
f"config={config_path}",
|
f"config={config_path}",
|
||||||
"-c",
|
"-c",
|
||||||
f"stack_name={STACK_NAME}",
|
f"stack_name={stack_name}",
|
||||||
|
"-c",
|
||||||
|
f"bootstrap_qualifier={bootstrap_qualifier}",
|
||||||
"-c",
|
"-c",
|
||||||
f"delete_bucket_data={str(delete_bucket_data).lower()}",
|
f"delete_bucket_data={str(delete_bucket_data).lower()}",
|
||||||
]
|
])
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
@@ -119,9 +155,9 @@ def _run(cmd: list[str]) -> None:
|
|||||||
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
|
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
|
||||||
|
|
||||||
|
|
||||||
def _read_outputs(path: Path) -> dict[str, str]:
|
def _read_outputs(path: Path, stack_name: str) -> dict[str, str]:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return {}
|
return {}
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return data.get(STACK_NAME, {})
|
return data.get(stack_name, {})
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from constructs import Construct
|
|||||||
from src.config import Config, MlflowMode
|
from src.config import Config, MlflowMode
|
||||||
|
|
||||||
|
|
||||||
class QaiStack(Stack):
|
class QCStack(Stack):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scope: Construct,
|
scope: Construct,
|
||||||
@@ -34,7 +34,7 @@ class QaiStack(Stack):
|
|||||||
role = iam.CfnRole(
|
role = iam.CfnRole(
|
||||||
self,
|
self,
|
||||||
"SageMakerRole",
|
"SageMakerRole",
|
||||||
role_name=config.sagemaker.role_name,
|
role_name=config.sagemaker.role_name or None,
|
||||||
assume_role_policy_document=self._sagemaker_trust_policy(),
|
assume_role_policy_document=self._sagemaker_trust_policy(),
|
||||||
managed_policy_arns=[
|
managed_policy_arns=[
|
||||||
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
|
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
|
||||||
@@ -74,6 +74,7 @@ class QaiStack(Stack):
|
|||||||
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
|
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
|
||||||
|
|
||||||
if config.mlflow.mode is MlflowMode.create:
|
if config.mlflow.mode is MlflowMode.create:
|
||||||
|
tracking_server_name = config.managed_mlflow_tracking_server_name
|
||||||
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
|
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
|
||||||
artifact_uri = (
|
artifact_uri = (
|
||||||
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
|
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
|
||||||
@@ -145,14 +146,14 @@ class QaiStack(Stack):
|
|||||||
"MlflowTrackingServer",
|
"MlflowTrackingServer",
|
||||||
artifact_store_uri=artifact_uri,
|
artifact_store_uri=artifact_uri,
|
||||||
role_arn=mlflow_role.attr_arn,
|
role_arn=mlflow_role.attr_arn,
|
||||||
tracking_server_name=config.mlflow.tracking_server_name or "",
|
tracking_server_name=tracking_server_name,
|
||||||
automatic_model_registration=config.mlflow.automatic_model_registration,
|
automatic_model_registration=config.mlflow.automatic_model_registration,
|
||||||
mlflow_version=config.mlflow.mlflow_version,
|
mlflow_version=config.mlflow.mlflow_version,
|
||||||
tracking_server_size=config.mlflow.tracking_server_size.value,
|
tracking_server_size=config.mlflow.tracking_server_size.value,
|
||||||
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
|
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
|
||||||
)
|
)
|
||||||
|
|
||||||
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
|
CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
|
||||||
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
|
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
|
||||||
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
|
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
|
||||||
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)
|
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
INFRA_STATE_FILE = ".qai-cli-infra.json"
|
INFRA_STATE_FILE = ".qc-cli-infra.json"
|
||||||
|
|
||||||
|
|
||||||
def state_path(config_dir: str) -> Path:
|
def state_path(config_dir: str) -> Path:
|
||||||
|
|||||||
39
src/main.py
39
src/main.py
@@ -1,39 +1,14 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
import yaml
|
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
from src.commands import infra
|
from src.commands import ai_hub, infra, init, mlflow, train, upload
|
||||||
from src.config import Config
|
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
help="qai-cli: End-to-end model managment for Qualcomm AI Hub.",
|
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
app.add_typer(init.app)
|
||||||
|
app.add_typer(upload.app)
|
||||||
|
app.add_typer(mlflow.app, name="mlflow")
|
||||||
app.add_typer(infra.app, name="infra")
|
app.add_typer(infra.app, name="infra")
|
||||||
|
app.add_typer(train.app, name="train")
|
||||||
console = Console()
|
app.add_typer(ai_hub.app, name="ai-hub")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def init(
|
|
||||||
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
|
||||||
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
|
||||||
) -> None:
|
|
||||||
"""Write a starter config.yaml to the current directory."""
|
|
||||||
dest = Path(output)
|
|
||||||
if dest.exists() and not force:
|
|
||||||
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
config = Config()
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(dest, "w") as f:
|
|
||||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
|
||||||
console.print(
|
|
||||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
|
||||||
"before running other commands."
|
|
||||||
)
|
|
||||||
|
|||||||
0
src/qualcomm/__init__.py
Normal file
0
src/qualcomm/__init__.py
Normal file
114
src/qualcomm/aihub_jobs.py
Normal file
114
src/qualcomm/aihub_jobs.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
import qai_hub.hub as hub
|
||||||
|
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
|
||||||
|
|
||||||
|
|
||||||
|
class ModelJobResult(TypedDict):
|
||||||
|
job: CompileJob | QuantizeJob
|
||||||
|
job_id: str
|
||||||
|
model: Model
|
||||||
|
model_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceJobResult(TypedDict):
|
||||||
|
job: InferenceJob
|
||||||
|
job_id: str
|
||||||
|
outputs: Any
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileJobResult(TypedDict):
|
||||||
|
job: ProfileJob
|
||||||
|
job_id: str
|
||||||
|
|
||||||
|
|
||||||
|
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
||||||
|
return {name: value if isinstance(value, list) else [value] for name, value in inputs.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def submit_compile_job(
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
||||||
|
target_runtime: str,
|
||||||
|
options: str | None = None,
|
||||||
|
job_name: str | None = None,
|
||||||
|
) -> ModelJobResult:
|
||||||
|
compile_options = f"--target_runtime {target_runtime}"
|
||||||
|
if options:
|
||||||
|
compile_options = f"{compile_options} {options}"
|
||||||
|
|
||||||
|
job = hub.submit_compile_job(
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
name=job_name,
|
||||||
|
input_specs=input_specs,
|
||||||
|
options=compile_options,
|
||||||
|
)
|
||||||
|
target_model = job.get_target_model()
|
||||||
|
if target_model is None:
|
||||||
|
raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.")
|
||||||
|
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def submit_inference_job(
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
inputs: dict[str, Any],
|
||||||
|
output_dir: str | Path,
|
||||||
|
job_name: str | None = None,
|
||||||
|
) -> InferenceJobResult:
|
||||||
|
job = hub.submit_inference_job(
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
inputs=_dataset_entries(inputs),
|
||||||
|
name=job_name,
|
||||||
|
)
|
||||||
|
out = Path(output_dir)
|
||||||
|
out.mkdir(parents=True, exist_ok=True)
|
||||||
|
data = job.download_output_data(str(out))
|
||||||
|
return {"job": job, "job_id": str(job.job_id), "outputs": data}
|
||||||
|
|
||||||
|
|
||||||
|
def submit_profile_job(
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
options: str | None = None,
|
||||||
|
job_name: str | None = None,
|
||||||
|
) -> ProfileJobResult:
|
||||||
|
job = hub.submit_profile_job(
|
||||||
|
model=model,
|
||||||
|
device=device,
|
||||||
|
name=job_name,
|
||||||
|
options=options or "",
|
||||||
|
)
|
||||||
|
return {"job": job, "job_id": str(job.job_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def submit_quantize_job(
|
||||||
|
model: Model,
|
||||||
|
calibration_data: dict[str, Any],
|
||||||
|
options: str | None = None,
|
||||||
|
job_name: str | None = None,
|
||||||
|
) -> ModelJobResult:
|
||||||
|
job = hub.submit_quantize_job(
|
||||||
|
model=model,
|
||||||
|
calibration_data=_dataset_entries(calibration_data),
|
||||||
|
weights_dtype=QuantizeDtype.INT8,
|
||||||
|
activations_dtype=QuantizeDtype.INT8,
|
||||||
|
name=job_name,
|
||||||
|
options=options or "",
|
||||||
|
)
|
||||||
|
target_model = job.get_target_model()
|
||||||
|
if target_model is None:
|
||||||
|
raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.")
|
||||||
|
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(model_id: str, output_path: str | Path) -> str:
|
||||||
|
dest = Path(output_path)
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
model = hub.get_model(model_id)
|
||||||
|
result = model.download(str(dest))
|
||||||
|
return str(result or dest)
|
||||||
83
src/qualcomm/artifacts.py
Normal file
83
src/qualcomm/artifacts.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import tarfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.aws import s3 as s3_ops
|
||||||
|
from src.aws import sagemaker as sm_ops
|
||||||
|
from src.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ResolvedOnnx:
|
||||||
|
onnx_path: Path
|
||||||
|
model_artifact: str | None
|
||||||
|
run_name: str
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_extract(tar: tarfile.TarFile, dest: Path) -> None:
|
||||||
|
dest_root = dest.resolve()
|
||||||
|
for member in tar.getmembers():
|
||||||
|
target = (dest / member.name).resolve()
|
||||||
|
if dest_root != target and dest_root not in target.parents:
|
||||||
|
raise ValueError(f"Unsafe tar member path: {member.name}")
|
||||||
|
tar.extractall(dest, filter="data")
|
||||||
|
|
||||||
|
|
||||||
|
def _find_onnx(root: Path, explicit: str | None = None) -> Path:
|
||||||
|
if explicit:
|
||||||
|
p = Path(explicit)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = root / p
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"ONNX file not found: {p}")
|
||||||
|
return p
|
||||||
|
|
||||||
|
matches = sorted(root.rglob("model.onnx"))
|
||||||
|
if not matches:
|
||||||
|
matches = sorted(root.rglob("*.onnx"))
|
||||||
|
if not matches:
|
||||||
|
raise FileNotFoundError(f"No ONNX file found under {root}")
|
||||||
|
if len(matches) > 1:
|
||||||
|
joined = ", ".join(str(p.relative_to(root)) for p in matches)
|
||||||
|
raise ValueError(f"Multiple ONNX files found ({joined}). Pass --onnx-path.")
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_onnx(
|
||||||
|
cfg: Config,
|
||||||
|
output_dir: str,
|
||||||
|
from_job: str | None = None,
|
||||||
|
model_s3_uri: str | None = None,
|
||||||
|
onnx_path: str | None = None,
|
||||||
|
last_training_job: str | None = None,
|
||||||
|
) -> ResolvedOnnx:
|
||||||
|
if onnx_path:
|
||||||
|
path = Path(onnx_path)
|
||||||
|
if path.exists():
|
||||||
|
return ResolvedOnnx(onnx_path=path, model_artifact=None, run_name=path.stem)
|
||||||
|
|
||||||
|
job = from_job or last_training_job
|
||||||
|
artifact = model_s3_uri
|
||||||
|
if not artifact:
|
||||||
|
if not job:
|
||||||
|
raise ValueError("No model source found. Pass --onnx-path, --model-s3-uri, --from-job, or run training first.")
|
||||||
|
artifact = sm_ops.get_model_artifacts(cfg.aws.region, cfg.aws.profile, job)
|
||||||
|
|
||||||
|
run_name = job or Path(artifact).name.removesuffix(".tar.gz").replace("/", "-")
|
||||||
|
root = Path(output_dir) / run_name / "source"
|
||||||
|
tar_path = root / "model.tar.gz"
|
||||||
|
s3_ops.download_file(cfg.aws.region, cfg.aws.profile, artifact, str(tar_path))
|
||||||
|
|
||||||
|
extract_dir = root / "extracted"
|
||||||
|
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
with tarfile.open(tar_path, "r:gz") as tar:
|
||||||
|
_safe_extract(tar, extract_dir)
|
||||||
|
except tarfile.TarError as e:
|
||||||
|
raise ValueError(f"Invalid model tarball: {tar_path}") from e
|
||||||
|
|
||||||
|
return ResolvedOnnx(
|
||||||
|
onnx_path=_find_onnx(extract_dir, onnx_path),
|
||||||
|
model_artifact=artifact,
|
||||||
|
run_name=run_name,
|
||||||
|
)
|
||||||
85
src/state.py
Normal file
85
src/state.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
STATE_FILE = ".qc-cli.json"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CliStateStore:
|
||||||
|
config_dir: str = "."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self) -> Path:
|
||||||
|
return Path(self.config_dir) / STATE_FILE
|
||||||
|
|
||||||
|
def read(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {}
|
||||||
|
with open(self.path) as f:
|
||||||
|
value = json.load(f)
|
||||||
|
return dict(value) if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
def update(self, **updates: Any) -> None:
|
||||||
|
state = self.read()
|
||||||
|
state.update(updates)
|
||||||
|
self._write(state)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
return self.read().get(key, default)
|
||||||
|
|
||||||
|
def get_last_training_job(self) -> str | None:
|
||||||
|
value = self.get("last_training_job")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def get_last_model_artifact(self) -> str | None:
|
||||||
|
value = self.get("last_model_artifact")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def get_last_optimized_model_id(self) -> str | None:
|
||||||
|
value = self.get("last_optimized_model_id")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def get_last_quantized_model_id(self) -> str | None:
|
||||||
|
value = self.get("last_quantized_model_id")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def get_last_compiled_model_id(self) -> str | None:
|
||||||
|
value = self.get("last_compiled_model_id")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def get_last_downloaded_model(self) -> str | None:
|
||||||
|
value = self.get("last_downloaded_model")
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
def set_last_training_job(self, job_name: str) -> None:
|
||||||
|
self.update(last_training_job=job_name)
|
||||||
|
|
||||||
|
def get_training_job(self, job_name: str) -> dict[str, Any]:
|
||||||
|
jobs = self._training_jobs(self.read())
|
||||||
|
value = jobs.get(job_name, {})
|
||||||
|
return dict(value) if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
def update_training_job(self, job_name: str, **updates: Any) -> None:
|
||||||
|
state = self.read()
|
||||||
|
jobs = self._training_jobs(state)
|
||||||
|
jobs[job_name] = {**jobs.get(job_name, {}), **updates}
|
||||||
|
state["training_jobs"] = jobs
|
||||||
|
self._write(state)
|
||||||
|
|
||||||
|
def set_latest_experiment_model_version(self, version: str) -> None:
|
||||||
|
self.update(latest_experiment_model_version=version)
|
||||||
|
|
||||||
|
def _write(self, state: dict[str, Any]) -> None:
|
||||||
|
with open(self.path, "w") as f:
|
||||||
|
json.dump(state, f, indent=2)
|
||||||
|
|
||||||
|
def _training_jobs(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
value = state.get("training_jobs", {})
|
||||||
|
return dict(value) if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def store(config_path: str) -> CliStateStore:
|
||||||
|
config_dir = str(Path(config_path).parent)
|
||||||
|
return CliStateStore(config_dir)
|
||||||
3
src/tracking/__init__.py
Normal file
3
src/tracking/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
|
||||||
|
|
||||||
|
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]
|
||||||
93
src/tracking/metrics.py
Normal file
93
src/tracking/metrics.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import json
|
||||||
|
import math
|
||||||
|
import tarfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import PurePosixPath
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
METRICS_ARTIFACT_NAME = "training_metrics.json"
|
||||||
|
METRICS_SCHEMA_VERSION = 1
|
||||||
|
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MetricStep:
|
||||||
|
step: int
|
||||||
|
metrics: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingMetrics:
|
||||||
|
steps: list[MetricStep]
|
||||||
|
summary: dict[str, float]
|
||||||
|
raw: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_training_metrics(data: bytes) -> TrainingMetrics:
|
||||||
|
try:
|
||||||
|
value = json.loads(data)
|
||||||
|
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
|
||||||
|
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
|
||||||
|
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
|
||||||
|
|
||||||
|
raw_steps = value.get("steps")
|
||||||
|
if not isinstance(raw_steps, list):
|
||||||
|
raise ValueError("training metrics 'steps' must be a list")
|
||||||
|
|
||||||
|
steps: list[MetricStep] = []
|
||||||
|
previous_step: int | None = None
|
||||||
|
for index, raw_step in enumerate(raw_steps):
|
||||||
|
if not isinstance(raw_step, dict):
|
||||||
|
raise ValueError(f"training metrics step {index} must be an object")
|
||||||
|
step = raw_step.get("step")
|
||||||
|
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
|
||||||
|
raise ValueError(f"training metrics step {index} has an invalid 'step'")
|
||||||
|
if previous_step is not None and step <= previous_step:
|
||||||
|
raise ValueError("training metrics steps must be unique and strictly increasing")
|
||||||
|
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
|
||||||
|
steps.append(MetricStep(step=step, metrics=metrics))
|
||||||
|
previous_step = step
|
||||||
|
|
||||||
|
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
|
||||||
|
return TrainingMetrics(steps=steps, summary=summary, raw=value)
|
||||||
|
|
||||||
|
|
||||||
|
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
|
||||||
|
with tarfile.open(archive_path, mode="r:*") as archive:
|
||||||
|
matches = [
|
||||||
|
member
|
||||||
|
for member in archive.getmembers()
|
||||||
|
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
|
||||||
|
]
|
||||||
|
if not matches:
|
||||||
|
return None
|
||||||
|
if len(matches) > 1:
|
||||||
|
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
|
||||||
|
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
|
||||||
|
)
|
||||||
|
extracted = archive.extractfile(matches[0])
|
||||||
|
if extracted is None:
|
||||||
|
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
|
||||||
|
return extracted.read()
|
||||||
|
|
||||||
|
|
||||||
|
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(f"{context} 'metrics' must be an object")
|
||||||
|
|
||||||
|
metrics: dict[str, float] = {}
|
||||||
|
for raw_name, raw_value in value.items():
|
||||||
|
if not isinstance(raw_name, str) or not raw_name:
|
||||||
|
raise ValueError(f"{context} contains an invalid metric name")
|
||||||
|
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
|
||||||
|
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
|
||||||
|
metric_value = float(raw_value)
|
||||||
|
if not math.isfinite(metric_value):
|
||||||
|
raise ValueError(f"{context} metric '{raw_name}' must be finite")
|
||||||
|
metrics[raw_name] = metric_value
|
||||||
|
return metrics
|
||||||
267
src/tracking/mlflow.py
Normal file
267
src/tracking/mlflow.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
from src.aws import s3
|
||||||
|
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
|
||||||
|
from src.config import Config, MlflowMode
|
||||||
|
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FinalizeResult:
|
||||||
|
registered_model_version: str | None = None
|
||||||
|
warnings: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
class Tracker(Protocol):
|
||||||
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
|
||||||
|
|
||||||
|
def finalize_training_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str | None,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
command: str,
|
||||||
|
) -> FinalizeResult: ...
|
||||||
|
|
||||||
|
def ensure_training_run(self, job_name: str) -> str: ...
|
||||||
|
|
||||||
|
def upload_training_metrics(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NoopTracker:
|
||||||
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def finalize_training_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str | None,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
command: str,
|
||||||
|
) -> FinalizeResult:
|
||||||
|
return FinalizeResult()
|
||||||
|
|
||||||
|
def ensure_training_run(self, job_name: str) -> str:
|
||||||
|
raise RuntimeError("MLflow is disabled.")
|
||||||
|
|
||||||
|
def upload_training_metrics(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
) -> bool:
|
||||||
|
raise RuntimeError("MLflow is disabled.")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MlflowTracker:
|
||||||
|
tracking_uri: str
|
||||||
|
experiment_name: str
|
||||||
|
registered_model_name: str
|
||||||
|
register_trained_models: bool
|
||||||
|
tracking_backend: MlflowTrackingBackend
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, cfg: Config) -> Tracker:
|
||||||
|
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||||
|
return NoopTracker()
|
||||||
|
|
||||||
|
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
|
||||||
|
|
||||||
|
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||||
|
if not tracking_server_name:
|
||||||
|
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
||||||
|
|
||||||
|
tracking_backend = mlflow_tracking_backend_from_config(cfg)
|
||||||
|
|
||||||
|
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
|
||||||
|
with tracking_backend.auth_env():
|
||||||
|
mlflow.set_tracking_uri(tracking_uri)
|
||||||
|
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
tracking_uri=tracking_uri,
|
||||||
|
experiment_name=cfg.mlflow.experiment_name,
|
||||||
|
registered_model_name=cfg.mlflow.registered_model_name,
|
||||||
|
register_trained_models=cfg.mlflow.register_trained_models,
|
||||||
|
tracking_backend=tracking_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||||
|
with self.tracking_backend.auth_env():
|
||||||
|
with mlflow.start_run(run_name=training_job.job_name) as run:
|
||||||
|
run_id = str(run.info.run_id)
|
||||||
|
self._log_params(
|
||||||
|
self.tracking_backend.training_run_params(
|
||||||
|
training_job,
|
||||||
|
region=region,
|
||||||
|
profile=profile,
|
||||||
|
role_arn=role_arn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||||
|
mlflow.set_tags(
|
||||||
|
{
|
||||||
|
"qc_cli.stage": "experiment",
|
||||||
|
"qc_cli.artifact_kind": "trained_source",
|
||||||
|
"qc_cli.source": self.tracking_backend.provider_name,
|
||||||
|
"qc_cli.command": "train start",
|
||||||
|
**self.tracking_backend.training_run_tags(training_job),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
def finalize_training_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str | None,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
command: str,
|
||||||
|
) -> FinalizeResult:
|
||||||
|
if not run_id:
|
||||||
|
return FinalizeResult()
|
||||||
|
|
||||||
|
with self.tracking_backend.auth_env():
|
||||||
|
with mlflow.start_run(run_id=run_id):
|
||||||
|
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
||||||
|
self._log_final_metrics(training_job_status.raw)
|
||||||
|
mlflow.set_tag("qc_cli.command", command)
|
||||||
|
|
||||||
|
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||||
|
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||||
|
return FinalizeResult()
|
||||||
|
|
||||||
|
if not self.register_trained_models:
|
||||||
|
return FinalizeResult()
|
||||||
|
|
||||||
|
client = MlflowClient()
|
||||||
|
self._ensure_registered_model(client, self.registered_model_name)
|
||||||
|
version = client.create_model_version(
|
||||||
|
name=self.registered_model_name,
|
||||||
|
source=training_job_status.model_artifacts,
|
||||||
|
run_id=run_id,
|
||||||
|
tags={
|
||||||
|
"qc_cli.stage": "experiment",
|
||||||
|
"qc_cli.artifact_kind": "trained_source",
|
||||||
|
"qc_cli.source": self.tracking_backend.provider_name,
|
||||||
|
**self.tracking_backend.model_version_tags(training_job_status),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
version_number = str(version.version)
|
||||||
|
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
||||||
|
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||||
|
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||||
|
return FinalizeResult(registered_model_version=version_number)
|
||||||
|
|
||||||
|
def ensure_training_run(self, job_name: str) -> str:
|
||||||
|
with self.tracking_backend.auth_env():
|
||||||
|
client = MlflowClient()
|
||||||
|
experiment = client.get_experiment_by_name(self.experiment_name)
|
||||||
|
if experiment is None:
|
||||||
|
experiment_id = mlflow.create_experiment(self.experiment_name)
|
||||||
|
else:
|
||||||
|
experiment_id = experiment.experiment_id
|
||||||
|
|
||||||
|
for run in client.search_runs([experiment_id], max_results=1000):
|
||||||
|
if run.data.tags.get("sagemaker.job_name") == job_name:
|
||||||
|
return str(run.info.run_id)
|
||||||
|
|
||||||
|
run = client.create_run(
|
||||||
|
experiment_id,
|
||||||
|
run_name=job_name,
|
||||||
|
tags={
|
||||||
|
"qc_cli.stage": "experiment",
|
||||||
|
"qc_cli.artifact_kind": "trained_source",
|
||||||
|
"qc_cli.source": self.tracking_backend.provider_name,
|
||||||
|
"qc_cli.command": "mlflow upload-metrics",
|
||||||
|
"sagemaker.job_name": job_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return str(run.info.run_id)
|
||||||
|
|
||||||
|
def upload_training_metrics(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
run_id: str,
|
||||||
|
training_job_status: Any,
|
||||||
|
region: str,
|
||||||
|
profile: str,
|
||||||
|
) -> bool:
|
||||||
|
if not training_job_status.model_artifacts:
|
||||||
|
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
|
||||||
|
|
||||||
|
with self.tracking_backend.auth_env():
|
||||||
|
with mlflow.start_run(run_id=run_id):
|
||||||
|
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
||||||
|
self._log_final_metrics(training_job_status.raw)
|
||||||
|
history_uploaded = self._log_training_metrics(
|
||||||
|
training_job_status.model_artifacts,
|
||||||
|
region=region,
|
||||||
|
profile=profile,
|
||||||
|
)
|
||||||
|
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
|
||||||
|
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
|
||||||
|
return history_uploaded
|
||||||
|
|
||||||
|
def _log_params(self, params: dict[str, Any]) -> None:
|
||||||
|
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||||
|
if cleaned:
|
||||||
|
mlflow.log_params(cleaned)
|
||||||
|
|
||||||
|
def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
|
||||||
|
metrics = {}
|
||||||
|
for metric in training_job.get("FinalMetricDataList", []):
|
||||||
|
name = metric.get("MetricName")
|
||||||
|
value = metric.get("Value")
|
||||||
|
if name and value is not None:
|
||||||
|
metrics[str(name)] = float(value)
|
||||||
|
if metrics:
|
||||||
|
mlflow.log_metrics(metrics)
|
||||||
|
|
||||||
|
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
|
||||||
|
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
||||||
|
archive_path = s3.download_file(
|
||||||
|
region,
|
||||||
|
profile,
|
||||||
|
model_artifacts,
|
||||||
|
os.path.join(temp_dir, "model.tar.gz"),
|
||||||
|
)
|
||||||
|
metrics_data = read_training_metrics_from_tar(archive_path)
|
||||||
|
if metrics_data is None:
|
||||||
|
return False
|
||||||
|
metrics = parse_training_metrics(metrics_data)
|
||||||
|
for metric_step in metrics.steps:
|
||||||
|
if metric_step.metrics:
|
||||||
|
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
||||||
|
if metrics.summary:
|
||||||
|
mlflow.log_metrics(metrics.summary)
|
||||||
|
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||||
|
try:
|
||||||
|
client.get_registered_model(name)
|
||||||
|
except Exception:
|
||||||
|
client.create_registered_model(name)
|
||||||
75
src/tracking/upload.py
Normal file
75
src/tracking/upload.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from src import state as state_ops
|
||||||
|
from src.aws import sagemaker as sm_ops
|
||||||
|
from src.config import Config, MlflowMode
|
||||||
|
from src.tracking.mlflow import MlflowTracker
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MetricsUploadResult:
|
||||||
|
run_id: str
|
||||||
|
registered_model_version: str | None = None
|
||||||
|
metrics_history_uploaded: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
def upload_training_metrics(
|
||||||
|
*,
|
||||||
|
job_name: str,
|
||||||
|
config_path: str,
|
||||||
|
cfg: Config,
|
||||||
|
force: bool = False,
|
||||||
|
) -> MetricsUploadResult:
|
||||||
|
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||||
|
raise RuntimeError("MLflow is disabled in config.yaml.")
|
||||||
|
|
||||||
|
st = state_ops.store(config_path)
|
||||||
|
job_state = st.get_training_job(job_name)
|
||||||
|
if job_state.get("mlflow_metrics_uploaded") and not force:
|
||||||
|
return MetricsUploadResult(
|
||||||
|
run_id=str(job_state.get("mlflow_run_id") or ""),
|
||||||
|
registered_model_version=(
|
||||||
|
str(job_state["registered_model_version"])
|
||||||
|
if job_state.get("registered_model_version")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||||
|
if status.status != "Completed":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker = MlflowTracker.from_config(cfg)
|
||||||
|
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
|
||||||
|
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||||
|
metrics_history_uploaded = tracker.upload_training_metrics(
|
||||||
|
run_id=run_id,
|
||||||
|
training_job_status=status,
|
||||||
|
region=cfg.aws.region,
|
||||||
|
profile=cfg.aws.profile,
|
||||||
|
)
|
||||||
|
finalized = tracker.finalize_training_run(
|
||||||
|
run_id=run_id,
|
||||||
|
training_job_status=status,
|
||||||
|
region=cfg.aws.region,
|
||||||
|
profile=cfg.aws.profile,
|
||||||
|
command="mlflow upload-metrics",
|
||||||
|
)
|
||||||
|
updates = {
|
||||||
|
"mlflow_metrics_uploaded": True,
|
||||||
|
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
|
||||||
|
"mlflow_finalized_status": status.status,
|
||||||
|
}
|
||||||
|
if finalized.registered_model_version:
|
||||||
|
updates["registered_model_version"] = finalized.registered_model_version
|
||||||
|
st.update_training_job(job_name, **updates)
|
||||||
|
if finalized.registered_model_version:
|
||||||
|
st.set_latest_experiment_model_version(finalized.registered_model_version)
|
||||||
|
return MetricsUploadResult(
|
||||||
|
run_id=run_id,
|
||||||
|
registered_model_version=finalized.registered_model_version,
|
||||||
|
metrics_history_uploaded=metrics_history_uploaded,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user