Compare commits
12 Commits
b907a74525
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| b56b77330c | |||
| a1ffbb77c5 | |||
| 522ddc74e2 | |||
|
|
5360a482fc | ||
|
|
6a560a8610 | ||
| d244150d98 | |||
| d7c7158464 | |||
| 6bc25dc183 | |||
|
|
71a95aa3a7 | ||
| a3f3060e13 | |||
| e9ada2612f | |||
| 6ac9702dc5 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -218,7 +218,7 @@ __marimo__/
|
||||
.streamlit/secrets.toml
|
||||
|
||||
.venv/
|
||||
config.yaml
|
||||
config*.yaml
|
||||
cdk.out/
|
||||
.qc-cli*.json
|
||||
examples/*/data/
|
||||
|
||||
137
README.md
137
README.md
@@ -30,7 +30,7 @@ qc-cli --help
|
||||
# 1. Create config.yaml in the current directory
|
||||
qc-cli init
|
||||
|
||||
# 2. Edit config.yaml — at minimum set s3.bucket and sagemaker.training.image_uri
|
||||
# 2. Edit config.yaml — at minimum set sagemaker.training.image_uri
|
||||
|
||||
# 3. Provision AWS infrastructure (S3 bucket + SageMaker IAM role).
|
||||
# This is the step that requires the AWS CDK CLI.
|
||||
@@ -47,15 +47,17 @@ qc-cli train status
|
||||
`qc-cli init` writes a `config.yaml` in the current directory. The fields you must fill in before using the tool:
|
||||
|
||||
```yaml
|
||||
infra:
|
||||
stack_name: qc-cli-mlops-1a2b3c4d5e6f
|
||||
|
||||
aws:
|
||||
region: us-east-1
|
||||
profile: default # AWS CLI profile name
|
||||
|
||||
s3:
|
||||
bucket: your-unique-bucket-name
|
||||
bucket: qc-cli-mlops-1a2b3c4d5e6f-data
|
||||
|
||||
sagemaker:
|
||||
role_name: qc-cli-sagemaker-role
|
||||
training:
|
||||
image_uri: "" # ECR URI for your training container
|
||||
instance_type: ml.m5.xlarge
|
||||
@@ -63,8 +65,24 @@ sagemaker:
|
||||
entry_point: null # Optional: script inside source_dir
|
||||
source_dir: null # Optional: local dir packaged and uploaded automatically
|
||||
hyperparameters: {}
|
||||
|
||||
aihub:
|
||||
device:
|
||||
name: Samsung Galaxy S25 (Family)
|
||||
target_runtime: tflite
|
||||
input_specs: {} # Required before running qc-cli ai-hub commands
|
||||
job_name: null # Optional prefix for AI Hub Workbench jobs
|
||||
model_name: null # Optional name for uploaded local ONNX models
|
||||
compile_options: null
|
||||
profile_options: null
|
||||
quantize_options: null
|
||||
output_dir: build/qai-hub
|
||||
```
|
||||
|
||||
`qc-cli init` generates the `infra.stack_name` and `s3.bucket` namespace once and writes it to `config.yaml`. Keep these values stable for a deployment; changing them points the CLI at different infrastructure.
|
||||
|
||||
The CLI isolates both application resources and CDK bootstrap resources. The application CloudFormation stack uses `infra.stack_name`, the S3 bucket uses the same generated namespace because bucket names are globally unique, and the SageMaker IAM role uses a CloudFormation-generated physical name. CDK bootstrap resources are derived internally from `infra.stack_name`, including a bootstrap stack named `<stack_name>-bootstrap` and a matching non-default CDK asset bucket qualifier. `qc-cli infra destroy` removes the application stack but leaves the CDK bootstrap stack in place; the command prints the retained bootstrap stack name.
|
||||
|
||||
`hyperparameters` is a flat map of values passed to the training container. Valid keys depend on the selected training image and entry point.
|
||||
|
||||
To provision an MLflow tracking server, set:
|
||||
@@ -72,12 +90,13 @@ To provision an MLflow tracking server, set:
|
||||
```yaml
|
||||
mlflow:
|
||||
mode: create
|
||||
tracking_server_name: your-tracking-server-name
|
||||
experiment_name: qc-cli-training
|
||||
registered_model_name: qc-cli-model
|
||||
register_trained_models: true
|
||||
```
|
||||
|
||||
In `create` mode, the CLI manages the tracking server name from `infra.stack_name`; you do not need to set `tracking_server_name`.
|
||||
|
||||
To use an existing MLflow tracking server, set:
|
||||
|
||||
```yaml
|
||||
@@ -86,13 +105,15 @@ mlflow:
|
||||
tracking_server_name: your-tracking-server-name
|
||||
```
|
||||
|
||||
Install the optional MLflow dependencies before enabling MLflow:
|
||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Training metrics can be upload with `train start --upload-metrics` or `mlflow upload-metrics`.
|
||||
|
||||
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
||||
|
||||
```bash
|
||||
uv sync --extra mlflow
|
||||
qc-cli mlflow open --config config.yaml
|
||||
```
|
||||
|
||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias.
|
||||
This opens a browser to a fresh presigned URL. It works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
|
||||
|
||||
## Commands
|
||||
|
||||
@@ -116,6 +137,21 @@ qc-cli infra destroy --yes Destroy stack without confirmation
|
||||
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
||||
```
|
||||
|
||||
`--cloudformation-execution-policy` is a one-time CDK bootstrap option, not a `config.yaml` setting. Pass it on `infra setup` when you need the CDK bootstrap CloudFormation execution role to use a policy other than the default `AdministratorAccess`:
|
||||
|
||||
```bash
|
||||
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
|
||||
```
|
||||
|
||||
### `mlflow`
|
||||
|
||||
```
|
||||
qc-cli mlflow open Open a presigned MLflow UI URL
|
||||
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
|
||||
```
|
||||
|
||||
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. Use `--force` to upload the metrics again.
|
||||
|
||||
### `upload`
|
||||
|
||||
```
|
||||
@@ -130,6 +166,7 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
|
||||
|
||||
```
|
||||
qc-cli train start Submit a SageMaker training job
|
||||
qc-cli train start --upload-metrics Submit, wait, and upload metrics
|
||||
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||
qc-cli train list List recent training jobs
|
||||
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
@@ -137,8 +174,94 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||
|
||||
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||
|
||||
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
||||
|
||||
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||
|
||||
### `ai-hub`
|
||||
|
||||
```
|
||||
qc-cli ai-hub upload <calibration.npz|calibration-dir> <inputs.npz|inputs.npy>
|
||||
qc-cli ai-hub upload <calibration> <inputs> --from-step validate
|
||||
qc-cli ai-hub optimize [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||
qc-cli ai-hub quantize <calibration.npz|calibration-dir> [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||
qc-cli ai-hub compile [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||
qc-cli ai-hub validate <inputs.npz|inputs.npy> [--model-id ID] [--input-name NAME]
|
||||
qc-cli ai-hub profile [--model-id ID]
|
||||
qc-cli ai-hub download [--model-id ID] [--output PATH]
|
||||
```
|
||||
|
||||
`ai-hub upload` optimizes to ONNX, quantizes, validates, and profiles. When `aihub.target_runtime` is not `onnx`, it also compiles the quantized model to that deployment runtime. The initial ONNX optimization gives external models Workbench provenance and applies compiler optimization passes before quantization.
|
||||
|
||||
Resume behavior:
|
||||
|
||||
```text
|
||||
--from-step optimize Run optimize, quantize, optional final compile, validate, and profile.
|
||||
--from-step quantize Quantize the last optimized ONNX, then optionally compile, validate, and profile.
|
||||
--from-step compile Skip optimize and quantize; finalize the last quantized model for the target runtime.
|
||||
--from-step validate Skip optimize, quantize, and compile; validate the last compiled model.
|
||||
--from-step profile Skip optimize, quantize, compile, and validate; profile the last compiled model.
|
||||
```
|
||||
|
||||
When a step runs in the current command, `upload` passes its returned model ID directly to the next step. When a step is skipped, the next step resolves the needed model ID from `.qc-cli.json`. This avoids re-running earlier AI Hub jobs when you only need to continue from a later step.
|
||||
|
||||
`ai-hub optimize` compiles an external model with `--target_runtime onnx`. `ai-hub quantize` uses an explicit `--model-id`, the last optimized ONNX model, or an explicit/local model source in that order. `ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options, last quantized model, then the last training job. For `target_runtime: onnx`, upload treats the quantized ONNX as the final model and skips a redundant second compile. `ai-hub download` remains separate because downloading is outside the Workbench processing loop.
|
||||
|
||||
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
|
||||
|
||||
## Model lifecycle
|
||||
|
||||
The CLI uses neutral experiment naming for trained artifacts and reserves release terminology for an explicit promotion step.
|
||||
|
||||
Current behavior:
|
||||
|
||||
1. `qc-cli train start` submits a SageMaker training job.
|
||||
2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
|
||||
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
|
||||
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
|
||||
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
|
||||
- `qc_cli.stage=experiment`
|
||||
- `qc_cli.artifact_kind=trained_source`
|
||||
- `qc_cli.source=sagemaker`
|
||||
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||
|
||||
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
|
||||
|
||||
```json
|
||||
{
|
||||
"schema_version": 1,
|
||||
"steps": [
|
||||
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
|
||||
],
|
||||
"summary": {"summary.best_epoch": 0}
|
||||
}
|
||||
```
|
||||
|
||||
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues model registration without per-epoch history. A malformed metrics artifact still fails the upload command without affecting the trained model or model registration.
|
||||
|
||||
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
||||
|
||||
Example future metadata:
|
||||
|
||||
```text
|
||||
qc-cli-model version 12
|
||||
qc_cli.stage=experiment
|
||||
qc_cli.artifact_kind=trained_source
|
||||
qc_cli.source=sagemaker
|
||||
|
||||
qc-cli-model-aihub version 3
|
||||
qc_cli.stage=ai_hub_compiled
|
||||
qc_cli.artifact_kind=deployable
|
||||
qc_cli.parent_registered_model_name=qc-cli-model
|
||||
qc_cli.parent_model_version=12
|
||||
qc_cli.runtime=tflite
|
||||
qc_cli.quantization=int8
|
||||
qc_cli.target_device=Samsung Galaxy S25
|
||||
```
|
||||
|
||||
In that flow, `experiment-latest` remains a training convenience alias. Release selection is a separate promotion decision based on the derived artifact, not on the experiment name.
|
||||
|
||||
## AWS permissions required
|
||||
|
||||
The IAM user or role running the CLI needs:
|
||||
|
||||
4
app.py
4
app.py
@@ -8,17 +8,19 @@ from src.infra.stack import QCStack
|
||||
app = cdk.App()
|
||||
|
||||
config_path = app.node.try_get_context("config") or "config.yaml"
|
||||
stack_name = app.node.try_get_context("stack_name") or "MLOpsStack"
|
||||
account_id = app.node.try_get_context("account_id") or os.getenv("CDK_DEFAULT_ACCOUNT")
|
||||
delete_bucket_data = str(app.node.try_get_context("delete_bucket_data") or "false").lower() == "true"
|
||||
|
||||
cfg = load_config(config_path)
|
||||
stack_name = app.node.try_get_context("stack_name") or cfg.infra.stack_name
|
||||
bootstrap_qualifier = app.node.try_get_context("bootstrap_qualifier") or cfg.infra.effective_bootstrap_qualifier
|
||||
|
||||
QCStack(
|
||||
app,
|
||||
stack_name,
|
||||
config=cfg,
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
synthesizer=cdk.DefaultStackSynthesizer(qualifier=bootstrap_qualifier),
|
||||
env=cdk.Environment(
|
||||
account=account_id,
|
||||
region=cfg.aws.region,
|
||||
|
||||
285
examples/meter-detection/README.md
Normal file
285
examples/meter-detection/README.md
Normal file
@@ -0,0 +1,285 @@
|
||||
# YOLO26 Electric Meter Detection Example
|
||||
|
||||
This example trains a YOLO26 object detection model on the Roboflow Universe electric meter dataset using the existing `qc-cli` SageMaker training flow.
|
||||
|
||||
The workflow is intentionally command driven. Run each step yourself so you can inspect the dataset, update `config.yaml`, and decide when to submit the SageMaker job.
|
||||
|
||||
Dataset:
|
||||
|
||||
```text
|
||||
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Install or sync the project dependencies: `uv sync`
|
||||
- The virtual environment is activated.
|
||||
- AWS credentials configured for the profile in `config.yaml`
|
||||
- Infrastructure already deployed with `qc-cli infra setup`
|
||||
|
||||
## 1. Download The Dataset
|
||||
|
||||
Register or sign in to Roboflow, then open the dataset page:
|
||||
|
||||
```text
|
||||
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
||||
```
|
||||
|
||||
Download the dataset in YOLOv26 format from the Roboflow UI, then extract the downloaded archive into:
|
||||
|
||||
```text
|
||||
examples/meter-detection/data/electric-meter-detection
|
||||
```
|
||||
|
||||
The `data.yaml` file should be directly under that folder:
|
||||
|
||||
```text
|
||||
examples/meter-detection/data/electric-meter-detection/data.yaml
|
||||
```
|
||||
|
||||
Do not move `data.yaml` into the `train/` split folder.
|
||||
|
||||
After extracting, confirm the dataset has a YOLO data file and image splits:
|
||||
|
||||
```bash
|
||||
find examples/meter-detection/data/electric-meter-detection -maxdepth 2 -type d | sort
|
||||
find examples/meter-detection/data/electric-meter-detection -name data.yaml -print
|
||||
```
|
||||
|
||||
Open `examples/meter-detection/data/electric-meter-detection/data.yaml` and make sure the split paths are relative to that folder:
|
||||
|
||||
```yaml
|
||||
path: .
|
||||
train: train/images
|
||||
val: valid/images
|
||||
test: test/images
|
||||
```
|
||||
|
||||
If your downloaded dataset does not include a `test/` folder, remove the `test:` line.
|
||||
|
||||
The expected layout is similar to:
|
||||
|
||||
```text
|
||||
examples/meter-detection/data/electric-meter-detection/
|
||||
data.yaml
|
||||
train/
|
||||
valid/
|
||||
test/
|
||||
```
|
||||
|
||||
## 2. Configure SageMaker Training
|
||||
|
||||
Update `config.yaml` so the training section points at this example's source directory:
|
||||
|
||||
```yaml
|
||||
sagemaker:
|
||||
training:
|
||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||
instance_type: ml.g4dn.xlarge
|
||||
instance_count: 1
|
||||
source_dir: examples/meter-detection/source
|
||||
entry_point: train.py
|
||||
hyperparameters:
|
||||
model: yolo26n.pt
|
||||
epochs: 25
|
||||
imgsz: 640
|
||||
batch: 16
|
||||
workers: 2
|
||||
```
|
||||
|
||||
Use `yolo26n.pt` for a lightweight first YOLO26 run. If those weights are unavailable in the installed Ultralytics package, use `yolo11n.pt` as the established fallback:
|
||||
|
||||
```yaml
|
||||
model: yolo11n.pt
|
||||
```
|
||||
|
||||
The `source/requirements.txt` file is installed by the SageMaker PyTorch container before running `train.py`.
|
||||
|
||||
For a CPU smoke test, use a CPU instance and reduce the workload:
|
||||
|
||||
```yaml
|
||||
sagemaker:
|
||||
training:
|
||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||
instance_type: ml.m4.xlarge
|
||||
instance_count: 1
|
||||
source_dir: examples/meter-detection/source
|
||||
entry_point: train.py
|
||||
hyperparameters:
|
||||
model: yolo26n.pt
|
||||
epochs: 1
|
||||
imgsz: 320
|
||||
batch: 4
|
||||
workers: 2
|
||||
```
|
||||
|
||||
## 3. Check Infrastructure
|
||||
|
||||
Confirm the CLI can see the configured SageMaker role and S3 bucket:
|
||||
|
||||
```bash
|
||||
qc-cli infra status
|
||||
```
|
||||
|
||||
## 4. Upload The Dataset
|
||||
|
||||
Upload the downloaded Roboflow dataset to the `s3.data_prefix` configured in `config.yaml`:
|
||||
|
||||
```bash
|
||||
qc-cli upload examples/meter-detection/data/electric-meter-detection
|
||||
```
|
||||
|
||||
Directory uploads preserve paths relative to the uploaded directory, so SageMaker receives the dataset root with `data.yaml` plus the split directories.
|
||||
|
||||
In SageMaker, this uploaded dataset root is mounted at `/opt/ml/input/data/train`. That `train` path is the SageMaker channel name, not the YOLO `train/` split folder.
|
||||
|
||||
## 5. Start Training
|
||||
|
||||
Submit the SageMaker training job:
|
||||
|
||||
```bash
|
||||
qc-cli train start
|
||||
```
|
||||
|
||||
The command prints the submitted SageMaker job name. Check progress with:
|
||||
|
||||
```bash
|
||||
qc-cli train status
|
||||
```
|
||||
|
||||
Or pass the job name explicitly:
|
||||
|
||||
```bash
|
||||
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
||||
```
|
||||
|
||||
To submit the job, wait for completion, and automatically import metrics and register the model, run:
|
||||
|
||||
```bash
|
||||
qc-cli train start --upload-metrics
|
||||
```
|
||||
|
||||
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
||||
|
||||
The metrics can be also submitted using:
|
||||
|
||||
```bash
|
||||
qc-cli mlflow upload-metrics
|
||||
```
|
||||
|
||||
## SageMaker Outputs
|
||||
|
||||
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
|
||||
|
||||
This example writes:
|
||||
|
||||
```text
|
||||
best.pt
|
||||
model.onnx
|
||||
metrics.json
|
||||
training_metrics.json
|
||||
```
|
||||
|
||||
The archive is stored under the configured `s3.model_prefix`.
|
||||
|
||||
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
|
||||
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
|
||||
are more meaningful than classification accuracy when assessing model quality.
|
||||
|
||||
## 6. Configure Qualcomm AI Hub
|
||||
|
||||
Authenticate with Qualcomm AI Hub:
|
||||
|
||||
```bash
|
||||
qai-hub configure --api_token
|
||||
```
|
||||
|
||||
Add AI Hub settings to `config.yaml`. The input name and image size must match the ONNX model exported by this example:
|
||||
|
||||
```yaml
|
||||
aihub:
|
||||
device:
|
||||
name: Dragonwing IQ-9075 EVK
|
||||
target_runtime: onnx
|
||||
input_specs:
|
||||
images: [[1, 3, 640, 640], float32]
|
||||
job_name: meter-detection
|
||||
model_name: meter-detection
|
||||
output_dir: build/qai-hub/meter-detection
|
||||
```
|
||||
|
||||
The ONNX graph is the source of truth. The export normally uses the same value as `sagemaker.training.hyperparameters.imgsz`, but changing `config.yaml` after training does not resize an existing model. For example, a model exported with `imgsz: 320` requires `images: [[1, 3, 320, 320], float32]`.
|
||||
|
||||
## 7. Prepare AI Hub Inputs
|
||||
|
||||
Generate calibration samples and a validation input from the downloaded dataset:
|
||||
|
||||
```bash
|
||||
uv run python examples/meter-detection/prepare_aihub_inputs.py --image-size 640
|
||||
```
|
||||
|
||||
This writes:
|
||||
|
||||
```text
|
||||
examples/meter-detection/data/aihub_calibration/*.npy
|
||||
examples/meter-detection/data/inputs.npz
|
||||
```
|
||||
|
||||
The script applies the preprocessing expected by the exported YOLO model: aspect-ratio-preserving letterboxing, RGB channel order, channel-first layout, and pixel values normalized to `[0, 1]`.
|
||||
|
||||
## 8. Upload To Qualcomm AI Hub
|
||||
|
||||
Use the SageMaker job name printed by `qc-cli train start`:
|
||||
|
||||
```bash
|
||||
qc-cli ai-hub upload \
|
||||
examples/meter-detection/data/aihub_calibration \
|
||||
examples/meter-detection/data/inputs.npz \
|
||||
--from-job qc-cli-YYYYMMDD-HHMMSS
|
||||
```
|
||||
|
||||
The command downloads the job's `model.tar.gz`, finds `model.onnx`, and runs the following AI Hub workflow:
|
||||
|
||||
1. Compile the external ONNX to a Workbench-optimized ONNX model.
|
||||
2. Quantize the optimized ONNX model.
|
||||
3. Compile the quantized model when the configured deployment runtime is not `onnx`.
|
||||
4. Validate and profile the final model.
|
||||
|
||||
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
|
||||
|
||||
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local model. Replace the job name in both paths:
|
||||
|
||||
```bash
|
||||
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
|
||||
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
|
||||
--output build/qai-hub/meter-detection/model.aihub.onnx
|
||||
|
||||
qc-cli ai-hub upload \
|
||||
examples/meter-detection/data/aihub_calibration \
|
||||
examples/meter-detection/data/inputs.npz \
|
||||
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
|
||||
```
|
||||
|
||||
Download the compiled artifact after the workflow completes:
|
||||
|
||||
```bash
|
||||
qc-cli ai-hub download --output build/qai-hub/meter-detection/model.tflite
|
||||
```
|
||||
|
||||
## Training Hyperparameters
|
||||
|
||||
Values under `sagemaker.training.hyperparameters` are passed to `source/train.py` as command-line arguments.
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|---|---:|---:|---|
|
||||
| `model` | string | `yolo26n.pt` | Ultralytics model weights or model YAML. |
|
||||
| `epochs` | int | `25` | Number of training epochs. |
|
||||
| `imgsz` | int | `640` | Square training image size. |
|
||||
| `batch` | int | `16` | Images per training batch. |
|
||||
| `workers` | int | `2` | DataLoader worker count. |
|
||||
| `patience` | int | `20` | Early stopping patience. |
|
||||
| `device` | string | auto | Optional Ultralytics device value such as `0` or `cpu`. |
|
||||
| `data-yaml` | string | auto | Optional path to `data.yaml`; normally discovered from the uploaded dataset root. |
|
||||
| `dataset-dir` | string | `SM_CHANNEL_TRAIN` | Uploaded dataset root mounted by SageMaker. |
|
||||
|
||||
Do not set `dataset-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
||||
92
examples/meter-detection/prepare_aihub_inputs.py
Normal file
92
examples/meter-detection/prepare_aihub_inputs.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare Qualcomm AI Hub calibration and validation inputs for the meter detector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--dataset-dir",
|
||||
type=Path,
|
||||
default=Path("examples/meter-detection/data/electric-meter-detection"),
|
||||
help="Root of the extracted Roboflow dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--calibration-dir",
|
||||
type=Path,
|
||||
default=Path("examples/meter-detection/data/aihub_calibration"),
|
||||
help="Directory where .npy calibration samples will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
type=Path,
|
||||
default=Path("examples/meter-detection/data/inputs.npz"),
|
||||
help="Validation .npz input file for qc-cli ai-hub validate.",
|
||||
)
|
||||
parser.add_argument("--input-name", default="images", help="ONNX input name.")
|
||||
parser.add_argument("--image-size", type=int, default=640, help="Square image size used for ONNX export.")
|
||||
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
|
||||
"""Apply Ultralytics-style letterboxing and produce an NCHW float32 tensor."""
|
||||
with Image.open(path) as source:
|
||||
image = source.convert("RGB")
|
||||
|
||||
scale = min(image_size / image.width, image_size / image.height)
|
||||
resized_width = round(image.width * scale)
|
||||
resized_height = round(image.height * scale)
|
||||
image = image.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
|
||||
|
||||
canvas = Image.new("RGB", (image_size, image_size), (114, 114, 114))
|
||||
left = round((image_size - resized_width) / 2 - 0.1)
|
||||
top = round((image_size - resized_height) / 2 - 0.1)
|
||||
canvas.paste(image, (left, top))
|
||||
|
||||
array = np.asarray(canvas, dtype=np.float32) / 255.0
|
||||
return np.transpose(array, (2, 0, 1))[None, ...].astype(np.float32)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.image_size < 1:
|
||||
raise SystemExit("--image-size must be at least 1")
|
||||
if args.samples < 1:
|
||||
raise SystemExit("--samples must be at least 1")
|
||||
|
||||
images = sorted(
|
||||
path
|
||||
for path in args.dataset_dir.rglob("*")
|
||||
if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS and path.parent.name == "images"
|
||||
)
|
||||
if not images:
|
||||
raise SystemExit(f"No images found under {args.dataset_dir}")
|
||||
|
||||
args.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.input_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
for stale_sample in args.calibration_dir.glob("sample_*.npy"):
|
||||
stale_sample.unlink()
|
||||
|
||||
prepared: list[np.ndarray] = []
|
||||
for index, image_path in enumerate(images[: args.samples]):
|
||||
sample = preprocess_image(image_path, args.image_size)
|
||||
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
|
||||
prepared.append(sample)
|
||||
|
||||
np.savez(args.input_file, **{args.input_name: prepared[0]}) # pyright: ignore[reportArgumentType]
|
||||
print(f"Wrote {len(prepared)} calibration samples to {args.calibration_dir}")
|
||||
print(f"Wrote validation input to {args.input_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
examples/meter-detection/source/requirements.txt
Normal file
3
examples/meter-detection/source/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
ultralytics>=8.3.0
|
||||
pyyaml>=6.0.3
|
||||
onnx>=1.16.0
|
||||
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
38
examples/meter-detection/source/sanitize_onnx.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import onnx # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
|
||||
model = onnx.load(path)
|
||||
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
|
||||
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
|
||||
|
||||
destination = output_path or path
|
||||
if len(retained_value_info) != len(model.graph.value_info):
|
||||
del model.graph.value_info[:]
|
||||
model.graph.value_info.extend(retained_value_info)
|
||||
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
onnx.save(model, destination)
|
||||
return destination
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("onnx_path", type=Path)
|
||||
parser.add_argument("--output", type=Path)
|
||||
args = parser.parse_args()
|
||||
|
||||
written = sanitize_onnx(args.onnx_path, args.output)
|
||||
print(f"Saved sanitized ONNX model to {written}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
128
examples/meter-detection/source/train.py
Normal file
128
examples/meter-detection/source/train.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env python3
|
||||
"""SageMaker entry point for YOLO electric meter detection training."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from sanitize_onnx import sanitize_onnx
|
||||
from training_metrics import write_training_metrics
|
||||
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default="yolo26n.pt")
|
||||
parser.add_argument("--epochs", type=int, default=25)
|
||||
parser.add_argument("--imgsz", type=int, default=640)
|
||||
parser.add_argument("--batch", type=int, default=16)
|
||||
parser.add_argument("--workers", type=int, default=2)
|
||||
parser.add_argument("--patience", type=int, default=20)
|
||||
parser.add_argument("--device", default=None)
|
||||
parser.add_argument("--data-yaml", default=None)
|
||||
parser.add_argument("--dataset-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
||||
parser.add_argument("--train-dir", dest="dataset_dir", help=argparse.SUPPRESS)
|
||||
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def find_data_yaml(dataset_dir: Path, explicit_path: str | None) -> Path:
|
||||
if explicit_path:
|
||||
data_yaml = Path(explicit_path)
|
||||
if data_yaml.is_file():
|
||||
return data_yaml
|
||||
raise FileNotFoundError(f"Configured data.yaml does not exist: {data_yaml}")
|
||||
|
||||
matches = sorted(dataset_dir.rglob("data.yaml"))
|
||||
if not matches:
|
||||
raise FileNotFoundError(f"Could not find data.yaml under {dataset_dir}")
|
||||
if len(matches) > 1:
|
||||
print(f"Found multiple data.yaml files; using {matches[0]}")
|
||||
return matches[0]
|
||||
|
||||
|
||||
def prepare_data_yaml(data_yaml: Path) -> Path:
|
||||
"""Write a SageMaker-local data file rooted at the uploaded dataset."""
|
||||
dataset_root = data_yaml.parent
|
||||
data = yaml.safe_load(data_yaml.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected a mapping in {data_yaml}")
|
||||
|
||||
normalized = dict(data)
|
||||
normalized["path"] = str(dataset_root)
|
||||
if "val" not in normalized and "valid" in normalized:
|
||||
normalized["val"] = normalized.pop("valid")
|
||||
|
||||
prepared_path = dataset_root / "data.sagemaker.yaml"
|
||||
prepared_path.write_text(yaml.safe_dump(normalized, sort_keys=False), encoding="utf-8")
|
||||
print(f"Prepared dataset config: {prepared_path}")
|
||||
return prepared_path
|
||||
|
||||
|
||||
def copy_if_exists(source: Path, destination: Path) -> None:
|
||||
if source.exists():
|
||||
shutil.copy2(source, destination)
|
||||
print(f"Saved {destination}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
model_dir = Path(args.model_dir)
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data_yaml = prepare_data_yaml(find_data_yaml(dataset_dir, args.data_yaml))
|
||||
model = YOLO(args.model)
|
||||
|
||||
train_kwargs: dict[str, Any] = {
|
||||
"data": str(data_yaml),
|
||||
"epochs": args.epochs,
|
||||
"imgsz": args.imgsz,
|
||||
"batch": args.batch,
|
||||
"workers": args.workers,
|
||||
"patience": args.patience,
|
||||
"project": str(model_dir / "runs"),
|
||||
"name": "train",
|
||||
"exist_ok": True,
|
||||
}
|
||||
if args.device:
|
||||
train_kwargs["device"] = args.device
|
||||
|
||||
results = model.train(**train_kwargs)
|
||||
save_dir = Path(results.save_dir)
|
||||
best_pt = save_dir / "weights" / "best.pt"
|
||||
last_pt = save_dir / "weights" / "last.pt"
|
||||
trained_weights = best_pt if best_pt.exists() else last_pt
|
||||
if not trained_weights.exists():
|
||||
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
|
||||
|
||||
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
|
||||
copy_if_exists(trained_weights, model_dir / "best.pt")
|
||||
trained_model = YOLO(str(trained_weights))
|
||||
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
||||
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
|
||||
print(f"Saved {saved_onnx_path}")
|
||||
|
||||
metrics = {
|
||||
"model": args.model,
|
||||
"epochs": args.epochs,
|
||||
"imgsz": args.imgsz,
|
||||
"batch": args.batch,
|
||||
"workers": args.workers,
|
||||
"patience": args.patience,
|
||||
"data_yaml": str(data_yaml),
|
||||
"weights": str(trained_weights),
|
||||
"onnx": str(saved_onnx_path),
|
||||
}
|
||||
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||
print(f"Saved model artifacts to {model_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
examples/meter-detection/source/training_metrics.py
Normal file
82
examples/meter-detection/source/training_metrics.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
METRIC_NAMES = {
|
||||
"metrics/precision(B)": "val.precision",
|
||||
"metrics/recall(B)": "val.recall",
|
||||
"metrics/mAP50(B)": "val.map50",
|
||||
"metrics/mAP50-95(B)": "val.map50_95",
|
||||
"train/box_loss": "train.box_loss",
|
||||
"train/cls_loss": "train.cls_loss",
|
||||
"train/dfl_loss": "train.dfl_loss",
|
||||
"val/box_loss": "val.box_loss",
|
||||
"val/cls_loss": "val.cls_loss",
|
||||
"val/dfl_loss": "val.dfl_loss",
|
||||
"time": "train.elapsed_seconds",
|
||||
}
|
||||
|
||||
|
||||
def write_training_metrics(results_csv: Path, destination: Path) -> None:
|
||||
steps = _read_metric_steps(results_csv)
|
||||
summary = _build_summary(steps)
|
||||
payload = {
|
||||
"schema_version": 1,
|
||||
"steps": steps,
|
||||
"summary": summary,
|
||||
}
|
||||
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
print(f"Saved {destination}")
|
||||
|
||||
|
||||
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
|
||||
if not results_csv.is_file():
|
||||
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
with results_csv.open(newline="", encoding="utf-8") as csv_file:
|
||||
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
|
||||
row = {str(key).strip(): value for key, value in raw_row.items()}
|
||||
raw_epoch = row.pop("epoch", row_index)
|
||||
step = int(float(raw_epoch))
|
||||
metrics: dict[str, float] = {}
|
||||
for source_name, raw_value in row.items():
|
||||
if raw_value is None or not raw_value.strip():
|
||||
continue
|
||||
try:
|
||||
value = float(raw_value)
|
||||
except ValueError:
|
||||
continue
|
||||
if math.isfinite(value):
|
||||
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
|
||||
steps.append({"step": step, "metrics": metrics})
|
||||
return steps
|
||||
|
||||
|
||||
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
|
||||
if not steps:
|
||||
return {}
|
||||
|
||||
summary: dict[str, float] = {}
|
||||
final_step = steps[-1]
|
||||
summary["summary.final_epoch"] = float(final_step["step"])
|
||||
for name, value in final_step["metrics"].items():
|
||||
summary[f"summary.final.{name}"] = value
|
||||
|
||||
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
|
||||
if scored_steps:
|
||||
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
|
||||
summary["summary.best_epoch"] = float(best_step["step"])
|
||||
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
|
||||
if "val.map50" in best_step["metrics"]:
|
||||
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
|
||||
return summary
|
||||
|
||||
|
||||
def _normalize_metric_name(name: str) -> str:
|
||||
normalized = name.replace("/", ".")
|
||||
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
|
||||
return normalized.strip("._") or "unnamed"
|
||||
@@ -1,90 +0,0 @@
|
||||
# SageMaker Training Example
|
||||
|
||||
This example downloads a small image-classification dataset, uploads it through `qc-cli`, and submits a live SageMaker training job.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- AWS credentials configured for the profile in `config.yaml`
|
||||
- Infrastructure already deployed with `qc-cli infra setup`
|
||||
- `config.yaml` updated with:
|
||||
|
||||
```yaml
|
||||
s3:
|
||||
bucket: your-bucket-name
|
||||
|
||||
sagemaker:
|
||||
role_name: <role-name>
|
||||
training:
|
||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||
instance_type: ml.m4.xlarge
|
||||
instance_count: 1
|
||||
source_dir: examples/training/source
|
||||
entry_point: train.py
|
||||
hyperparameters:
|
||||
epochs: 1
|
||||
batch-size: 32
|
||||
learning-rate: 0.001
|
||||
image-size: 160
|
||||
validation-split: 0.2
|
||||
```
|
||||
|
||||
## Training Hyperparameters
|
||||
|
||||
Values under `sagemaker.training.hyperparameters` are passed to the training entry point as command-line arguments. For this example, they map to arguments defined in [source/train.py](source/train.py).
|
||||
|
||||
Supported by this example:
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|---|---:|---:|---|
|
||||
| `epochs` | int | `1` | Number of training epochs. |
|
||||
| `batch-size` | int | `32` | Images per training batch. |
|
||||
| `learning-rate` | float | `0.001` | Adam optimizer learning rate. |
|
||||
| `image-size` | int | `160` | Resize images to square `image-size x image-size`. |
|
||||
| `validation-split` | float | `0.2` | Fraction of data used for validation. |
|
||||
| `max-samples` | int | `0` | Optional cap for smoke tests; `0` means use all images. |
|
||||
| `seed` | int | `13` | Random seed for reproducible splitting. |
|
||||
| `num-workers` | int | `2` | DataLoader worker count. |
|
||||
|
||||
Do not set `train-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
||||
|
||||
## 1. Download The Dataset
|
||||
|
||||
```bash
|
||||
bash examples/training/download_flower_photos.sh
|
||||
```
|
||||
|
||||
This creates:
|
||||
|
||||
```text
|
||||
examples/training/data/flower_photos_sagemaker/
|
||||
daisy/
|
||||
dandelion/
|
||||
roses/
|
||||
sunflowers/
|
||||
tulips/
|
||||
```
|
||||
|
||||
## 2. Run Training
|
||||
|
||||
Run the training script and wait until it finishes:
|
||||
|
||||
```bash
|
||||
bash examples/training/run_training.sh --config config.yaml --wait
|
||||
```
|
||||
|
||||
Use a dataset that is already uploaded to `s3.data_prefix`:
|
||||
|
||||
```bash
|
||||
bash examples/training/run_training.sh \
|
||||
--config config.yaml \
|
||||
--skip-upload \
|
||||
--wait
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The default dataset path is `examples/training/data/flower_photos_sagemaker`.
|
||||
- Uploaded data uses the `s3.bucket` and `s3.data_prefix` values from `config.yaml`.
|
||||
- Training artifacts are written under `s3://<bucket>/<model_prefix>/`.
|
||||
- The SageMaker `model.tar.gz` contains `model.onnx`, `model.pt`, `class_to_idx.json`, and `metrics.json`.
|
||||
- SageMaker packages `examples/training/source`, installs `requirements.txt`, and runs `train.py`.
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
DATASET_URL="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
|
||||
DEST_DIR="${1:-examples/training/data}"
|
||||
ARCHIVE_PATH="${DEST_DIR}/flower_photos.tgz"
|
||||
RAW_DATASET_DIR="${DEST_DIR}/flower_photos"
|
||||
DATASET_DIR="${DEST_DIR}/flower_photos_sagemaker"
|
||||
CLASS_NAMES=("daisy" "dandelion" "roses" "sunflowers" "tulips")
|
||||
|
||||
mkdir -p "${DEST_DIR}"
|
||||
|
||||
if [[ -d "${DATASET_DIR}" ]]; then
|
||||
echo "Dataset already exists: ${DATASET_DIR}"
|
||||
echo "Use this path with run_training.py:"
|
||||
echo " ${DATASET_DIR}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Downloading TensorFlow flower_photos dataset..."
|
||||
if command -v curl >/dev/null 2>&1; then
|
||||
curl -L "${DATASET_URL}" -o "${ARCHIVE_PATH}"
|
||||
elif command -v wget >/dev/null 2>&1; then
|
||||
wget -O "${ARCHIVE_PATH}" "${DATASET_URL}"
|
||||
else
|
||||
echo "Either curl or wget is required." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Extracting dataset..."
|
||||
tar -xzf "${ARCHIVE_PATH}" -C "${DEST_DIR}"
|
||||
|
||||
echo "Preparing SageMaker directory layout..."
|
||||
mkdir -p "${DATASET_DIR}"
|
||||
for class_name in "${CLASS_NAMES[@]}"; do
|
||||
cp -R "${RAW_DATASET_DIR}/${class_name}" "${DATASET_DIR}/${class_name}"
|
||||
done
|
||||
|
||||
echo "Dataset ready: ${DATASET_DIR}"
|
||||
find "${DATASET_DIR}" -mindepth 1 -maxdepth 1 -type d -print | sort
|
||||
@@ -1,111 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
CONFIG_PATH="config.yaml"
|
||||
DATASET_DIR="examples/training/data/flower_photos_sagemaker"
|
||||
WAIT=false
|
||||
SKIP_UPLOAD=false
|
||||
POLL_SECONDS=60
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: $0 [options]
|
||||
|
||||
Options:
|
||||
--config PATH Path to qc-cli config file. Default: config.yaml
|
||||
--dataset-dir PATH Dataset directory to upload. Default: ${DATASET_DIR}
|
||||
--skip-upload Train against data already uploaded to s3.data_prefix.
|
||||
--wait Poll until training completes.
|
||||
-h, --help Show this help.
|
||||
EOF
|
||||
}
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--config)
|
||||
CONFIG_PATH="$2"
|
||||
shift 2
|
||||
;;
|
||||
--dataset-dir)
|
||||
DATASET_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--skip-upload)
|
||||
SKIP_UPLOAD=true
|
||||
shift
|
||||
;;
|
||||
--wait)
|
||||
WAIT=true
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ ! -f "${CONFIG_PATH}" ]]; then
|
||||
echo "Config not found: ${CONFIG_PATH}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "${SKIP_UPLOAD}" == false && ! -d "${DATASET_DIR}" ]]; then
|
||||
echo "Dataset not found: ${DATASET_DIR}" >&2
|
||||
echo "Run: bash examples/training/download_flower_photos.sh" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
run() {
|
||||
echo "+ $*"
|
||||
"$@"
|
||||
}
|
||||
|
||||
run uv run qc-cli infra status --config "${CONFIG_PATH}"
|
||||
|
||||
if [[ "${SKIP_UPLOAD}" == false ]]; then
|
||||
run uv run qc-cli upload "${DATASET_DIR}" --config "${CONFIG_PATH}"
|
||||
fi
|
||||
|
||||
TRAIN_OUTPUT="$(uv run qc-cli train start --config "${CONFIG_PATH}")"
|
||||
echo "${TRAIN_OUTPUT}"
|
||||
|
||||
JOB_NAME="$(printf '%s\n' "${TRAIN_OUTPUT}" | grep -Eo 'qc-cli-[0-9]{8}-[0-9]{6}' | tail -n 1)"
|
||||
if [[ -z "${JOB_NAME}" ]]; then
|
||||
echo "Could not find training job name in qc-cli output." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Submitted SageMaker training job: ${JOB_NAME}"
|
||||
|
||||
if [[ "${WAIT}" == false ]]; then
|
||||
run uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
while true; do
|
||||
STATUS_OUTPUT="$(uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}")"
|
||||
echo "${STATUS_OUTPUT}"
|
||||
|
||||
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Completed'; then
|
||||
echo "Training completed successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Failed'; then
|
||||
echo "Training failed." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Stopped'; then
|
||||
echo "Training stopped." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep "${POLL_SECONDS}"
|
||||
done
|
||||
@@ -1 +0,0 @@
|
||||
onnx==1.21.0
|
||||
@@ -1,192 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""SageMaker entry point for CPU image-classification training."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Subset, random_split
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
|
||||
class SmallImageClassifier(nn.Module):
|
||||
def __init__(self, class_count: int) -> None:
|
||||
super().__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
)
|
||||
self.classifier = nn.Linear(64, class_count)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.features(x)
|
||||
x = torch.flatten(x, 1)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--epochs", type=int, default=1)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument("--learning-rate", type=float, default=0.001)
|
||||
parser.add_argument("--image-size", type=int, default=160)
|
||||
parser.add_argument("--validation-split", type=float, default=0.2)
|
||||
parser.add_argument("--max-samples", type=int, default=0)
|
||||
parser.add_argument("--seed", type=int, default=13)
|
||||
parser.add_argument("--num-workers", type=int, default=2)
|
||||
parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
||||
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]:
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((args.image_size, args.image_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
]
|
||||
)
|
||||
dataset = datasets.ImageFolder(args.train_dir, transform=transform)
|
||||
if len(dataset.classes) < 2:
|
||||
raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {dataset.classes}")
|
||||
|
||||
if args.max_samples > 0 and args.max_samples < len(dataset):
|
||||
indices = list(range(len(dataset)))
|
||||
random.Random(args.seed).shuffle(indices)
|
||||
dataset = Subset(dataset, indices[: args.max_samples])
|
||||
|
||||
validation_size = max(1, int(len(dataset) * args.validation_split))
|
||||
train_size = len(dataset) - validation_size
|
||||
if train_size < 1:
|
||||
raise ValueError("Not enough images to create a train/validation split.")
|
||||
|
||||
generator = torch.Generator().manual_seed(args.seed)
|
||||
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=generator)
|
||||
return train_dataset, validation_dataset, getattr(dataset, "dataset", dataset).class_to_idx
|
||||
|
||||
|
||||
def run_epoch(
|
||||
model: nn.Module,
|
||||
data_loader: DataLoader,
|
||||
criterion: nn.Module,
|
||||
optimizer: torch.optim.Optimizer | None,
|
||||
device: torch.device,
|
||||
) -> tuple[float, float]:
|
||||
training = optimizer is not None
|
||||
model.train(training)
|
||||
|
||||
total_loss = 0.0
|
||||
total_correct = 0
|
||||
total_examples = 0
|
||||
|
||||
for images, labels in data_loader:
|
||||
images = images.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
with torch.set_grad_enabled(training):
|
||||
logits = model(images)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
if training:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item() * images.size(0)
|
||||
total_correct += (logits.argmax(dim=1) == labels).sum().item()
|
||||
total_examples += images.size(0)
|
||||
|
||||
return total_loss / total_examples, total_correct / total_examples
|
||||
|
||||
|
||||
def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None:
|
||||
model.eval()
|
||||
dummy_input = torch.randn(1, 3, image_size, image_size)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
dummy_input,
|
||||
model_dir / "model.onnx",
|
||||
export_params=True,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
input_names=["input"],
|
||||
output_names=["logits"],
|
||||
dynamic_axes={
|
||||
"input": {0: "batch_size"},
|
||||
"logits": {0: "batch_size"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
train_dataset, validation_dataset, class_to_idx = build_datasets(args)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = SmallImageClassifier(class_count=len(class_to_idx)).to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
print(f"Training on {device}. Classes: {sorted(class_to_idx)}")
|
||||
metrics = []
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device)
|
||||
validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device)
|
||||
epoch_metrics = {
|
||||
"epoch": epoch,
|
||||
"train_loss": train_loss,
|
||||
"train_accuracy": train_accuracy,
|
||||
"validation_loss": validation_loss,
|
||||
"validation_accuracy": validation_accuracy,
|
||||
}
|
||||
metrics.append(epoch_metrics)
|
||||
print(json.dumps(epoch_metrics, sort_keys=True))
|
||||
|
||||
model_dir = Path(args.model_dir)
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(
|
||||
{
|
||||
"model_state_dict": model.cpu().state_dict(),
|
||||
"class_to_idx": class_to_idx,
|
||||
"image_size": args.image_size,
|
||||
},
|
||||
model_dir / "model.pt",
|
||||
)
|
||||
export_onnx(model, model_dir, args.image_size)
|
||||
(model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8")
|
||||
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||
print(f"Saved model artifacts to {model_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,20 +5,18 @@ build-backend = "hatchling.build"
|
||||
[project]
|
||||
name = "qc-cli"
|
||||
version = "0.1.0"
|
||||
description = "CLI for SageMaker ONNX training and Qualcomm AI Hub optimization"
|
||||
description = "CLI for training and deploying models for Qualcomm AI Hub"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"aws-cdk-lib>=2.180.0",
|
||||
"typer==0.25.0",
|
||||
"boto3>=1.34,<1.42",
|
||||
"constructs>=10.0.0",
|
||||
"mlflow>=3.0",
|
||||
"numpy>=1.26",
|
||||
"pydantic>=2.13.3",
|
||||
"pyyaml>=6.0.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
mlflow = [
|
||||
"mlflow>=3.0",
|
||||
"qai-hub>=0.49.0",
|
||||
"sagemaker-mlflow>=0.4.0",
|
||||
]
|
||||
|
||||
@@ -31,7 +29,6 @@ packages = ["src"]
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"boto3-stubs[iam,s3,sagemaker]",
|
||||
"pytest>=8.0",
|
||||
"pyright>=1.1.409",
|
||||
"types-PyYAML",
|
||||
"ruff>=0.4",
|
||||
|
||||
@@ -3,13 +3,11 @@ from typing import Any
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from src.infra.provisioning import STACK_NAME
|
||||
|
||||
|
||||
def stack_status(region: str, profile: str) -> dict[str, Any] | None:
|
||||
def stack_status(region: str, profile: str, stack_name: str) -> dict[str, Any] | None:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("cloudformation")
|
||||
try:
|
||||
stack = client.describe_stacks(StackName=STACK_NAME)["Stacks"][0]
|
||||
stack = client.describe_stacks(StackName=stack_name)["Stacks"][0]
|
||||
except ClientError as e:
|
||||
message = e.response.get("Error", {}).get("Message", "")
|
||||
if "does not exist" in message:
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, cast
|
||||
|
||||
import boto3
|
||||
@@ -28,3 +31,44 @@ def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
|
||||
if not arn:
|
||||
raise ValueError(f"MLflow tracking server has no ARN: {name}")
|
||||
return str(arn)
|
||||
|
||||
|
||||
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
||||
return str(response["AuthorizedUrl"])
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
|
||||
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
|
||||
if credentials is None:
|
||||
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
|
||||
|
||||
frozen_credentials = credentials.get_frozen_credentials()
|
||||
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
|
||||
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
|
||||
|
||||
env_updates = {
|
||||
"AWS_PROFILE": profile,
|
||||
"AWS_DEFAULT_REGION": region,
|
||||
"AWS_REGION": region,
|
||||
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
|
||||
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
|
||||
}
|
||||
if frozen_credentials.token:
|
||||
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
|
||||
|
||||
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
|
||||
previous_env = {key: os.environ.get(key) for key in restore_keys}
|
||||
try:
|
||||
os.environ.update(env_updates)
|
||||
if not frozen_credentials.token:
|
||||
os.environ.pop("AWS_SESSION_TOKEN", None)
|
||||
yield
|
||||
finally:
|
||||
for key, value in previous_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
@@ -21,6 +21,24 @@ def upload_file(
|
||||
return f"s3://{bucket}/{s3_key}"
|
||||
|
||||
|
||||
def download_file(
|
||||
region: str,
|
||||
profile: str,
|
||||
s3_uri: str,
|
||||
local_path: str,
|
||||
) -> str:
|
||||
if not s3_uri.startswith("s3://"):
|
||||
raise ValueError(f"Expected S3 URI, got: {s3_uri}")
|
||||
bucket_key = s3_uri.removeprefix("s3://")
|
||||
bucket, _, key = bucket_key.partition("/")
|
||||
if not bucket or not key:
|
||||
raise ValueError(f"Expected S3 URI with bucket and key, got: {s3_uri}")
|
||||
dest = Path(local_path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_client(region, profile).download_file(bucket, key, str(dest))
|
||||
return str(dest)
|
||||
|
||||
|
||||
def upload_dir(
|
||||
region: str,
|
||||
profile: str,
|
||||
|
||||
@@ -121,6 +121,16 @@ def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> Train
|
||||
)
|
||||
|
||||
|
||||
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
|
||||
resp = boto3.Session(profile_name=profile, region_name=region).client("sagemaker").describe_training_job(
|
||||
TrainingJobName=job_name
|
||||
)
|
||||
artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
|
||||
if not artifact:
|
||||
raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
|
||||
return str(artifact)
|
||||
|
||||
|
||||
def list_training_jobs(
|
||||
session: Boto3SessionKwargs,
|
||||
max_results: int = 10,
|
||||
|
||||
1
src/cloud/__init__.py
Normal file
1
src/cloud/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Cloud provider adapters."""
|
||||
77
src/cloud/mlflow.py
Normal file
77
src/cloud/mlflow.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class MlflowTrackingBackend(Protocol):
|
||||
@property
|
||||
def provider_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def profile(self) -> str: ...
|
||||
|
||||
@property
|
||||
def region(self) -> str: ...
|
||||
|
||||
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
|
||||
|
||||
def auth_env(self) -> AbstractContextManager[None]: ...
|
||||
|
||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
|
||||
|
||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
|
||||
|
||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||
|
||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AwsMlflowTrackingBackend:
|
||||
profile: str
|
||||
region: str
|
||||
provider_name: str = "aws"
|
||||
|
||||
def get_tracking_uri(self, tracking_server_name: str) -> str:
|
||||
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
|
||||
|
||||
def auth_env(self) -> AbstractContextManager[None]:
|
||||
return aws_mlflow.tracking_auth_env(self.profile, self.region)
|
||||
|
||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
|
||||
return {
|
||||
"provider.name": self.provider_name,
|
||||
"provider.region": region,
|
||||
"provider.profile": profile,
|
||||
"sagemaker.role_arn": role_arn,
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
"sagemaker.training_image": training_job.image_uri,
|
||||
"sagemaker.instance_type": training_job.instance_type,
|
||||
"sagemaker.instance_count": training_job.instance_count,
|
||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||
"sagemaker.entry_point": training_job.entry_point,
|
||||
"sagemaker.source_dir": training_job.source_dir,
|
||||
}
|
||||
|
||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
|
||||
return {"sagemaker.job_name": training_job.job_name}
|
||||
|
||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"sagemaker.training_status": training_job_status.status,
|
||||
"sagemaker.created_at": training_job_status.created,
|
||||
"sagemaker.modified_at": training_job_status.modified,
|
||||
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||
}
|
||||
|
||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
|
||||
return {"sagemaker.job_name": training_job_status.name}
|
||||
|
||||
|
||||
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
|
||||
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)
|
||||
567
src/commands/ai_hub.py
Normal file
567
src/commands/ai_hub.py
Normal file
@@ -0,0 +1,567 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import qai_hub.hub as hub
|
||||
import typer
|
||||
from qai_hub.client import Device
|
||||
|
||||
from src import state as state_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config
|
||||
from src.qualcomm import aihub_jobs
|
||||
from src.qualcomm.artifacts import ResolvedOnnx, resolve_onnx
|
||||
|
||||
app = typer.Typer(help="Optimize, quantize, compile, validate, profile, and download models with Qualcomm Workbench")
|
||||
|
||||
_RUNTIME_EXTENSIONS = {
|
||||
"tflite": "tflite",
|
||||
"qnn_context_binary": "bin",
|
||||
"onnx": "onnx",
|
||||
}
|
||||
|
||||
|
||||
class UploadStep(StrEnum):
|
||||
optimize = "optimize"
|
||||
quantize = "quantize"
|
||||
compile = "compile"
|
||||
validate = "validate"
|
||||
profile = "profile"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedModelSource:
|
||||
model: str | Path
|
||||
model_artifact: str | None = None
|
||||
|
||||
|
||||
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
|
||||
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
|
||||
if not specs:
|
||||
CONSOLE.print("[red]aihub.input_specs must define at least one input.[/red]")
|
||||
raise typer.Exit(1)
|
||||
return specs
|
||||
|
||||
|
||||
def _load_inputs(
|
||||
input_file: Path,
|
||||
specs: Mapping[str, tuple[Sequence[int], str]],
|
||||
input_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
import numpy as np
|
||||
|
||||
if not input_file.exists():
|
||||
raise FileNotFoundError(f"File not found: {input_file}")
|
||||
|
||||
if input_file.suffix == ".npz":
|
||||
loaded = np.load(input_file)
|
||||
missing = set(specs) - set(loaded.files)
|
||||
if missing:
|
||||
raise ValueError(f"Missing input(s) in NPZ: {', '.join(sorted(missing))}")
|
||||
return {name: loaded[name] for name in specs}
|
||||
|
||||
if input_file.suffix == ".npy":
|
||||
if input_name is None:
|
||||
if len(specs) != 1:
|
||||
raise ValueError("--input-name is required when config has multiple inputs")
|
||||
input_name = next(iter(specs))
|
||||
if input_name not in specs:
|
||||
raise ValueError(f"Input name '{input_name}' is not defined in aihub.input_specs")
|
||||
return {input_name: np.load(input_file)}
|
||||
|
||||
raise ValueError("Input file must be .npz or .npy")
|
||||
|
||||
|
||||
def _load_calibration(path: Path, specs: Mapping[str, tuple[Sequence[int], str]]) -> dict[str, Any]:
|
||||
import numpy as np
|
||||
|
||||
if path.is_file():
|
||||
return _load_inputs(path, specs)
|
||||
|
||||
if not path.is_dir():
|
||||
raise FileNotFoundError(f"Calibration path not found: {path}")
|
||||
|
||||
if len(specs) != 1:
|
||||
raise ValueError("Directory calibration data is supported only for single-input models.")
|
||||
input_name = next(iter(specs))
|
||||
samples = [np.load(p) for p in sorted(path.glob("*.npy"))]
|
||||
if not samples:
|
||||
raise ValueError(f"No .npy calibration samples found in {path}")
|
||||
return {input_name: samples}
|
||||
|
||||
|
||||
def _job_name(cfg: Config, operation: str) -> str | None:
|
||||
if not cfg.aihub.job_name:
|
||||
return None
|
||||
return f"{cfg.aihub.job_name}-{operation}"
|
||||
|
||||
|
||||
def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: bool = False) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
resolved = model_id or (st.get_last_quantized_model_id() if quantized else st.get_last_compiled_model_id())
|
||||
if not resolved:
|
||||
source = "quantized" if quantized else "compiled"
|
||||
CONSOLE.print(f"[red]No {source} model found. Pass --model-id or run the previous AI Hub step first.[/red]")
|
||||
raise typer.Exit(1)
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_model_source(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
previous_model_id: str | None = None,
|
||||
from_job: str | None = None,
|
||||
model_s3_uri: str | None = None,
|
||||
onnx_path: str | None = None,
|
||||
) -> ResolvedModelSource:
|
||||
if model_id:
|
||||
return ResolvedModelSource(model_id)
|
||||
|
||||
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
|
||||
if previous_model_id and not has_explicit_source:
|
||||
return ResolvedModelSource(previous_model_id)
|
||||
|
||||
resolved = _resolve_onnx_source(
|
||||
cfg,
|
||||
config_path,
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
return ResolvedModelSource(resolved.onnx_path, resolved.model_artifact)
|
||||
|
||||
|
||||
def _resolve_onnx_source(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
*,
|
||||
from_job: str | None = None,
|
||||
model_s3_uri: str | None = None,
|
||||
onnx_path: str | None = None,
|
||||
) -> ResolvedOnnx:
|
||||
st = state_ops.store(config_path)
|
||||
last_training_job = st.get_last_training_job()
|
||||
saved_model_artifact = None
|
||||
if not from_job and not model_s3_uri and not onnx_path and not last_training_job:
|
||||
saved_model_artifact = st.get_last_model_artifact()
|
||||
|
||||
return resolve_onnx(
|
||||
cfg=cfg,
|
||||
output_dir=cfg.aihub.output_dir,
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri or saved_model_artifact,
|
||||
onnx_path=onnx_path,
|
||||
last_training_job=last_training_job,
|
||||
)
|
||||
|
||||
|
||||
def _device_selector(device: Device) -> str:
|
||||
parts: list[str] = []
|
||||
if device.name:
|
||||
parts.append(f"name={device.name!r}")
|
||||
if device.os:
|
||||
parts.append(f"os={device.os!r}")
|
||||
if device.attributes:
|
||||
parts.append(f"attributes={device.attributes!r}")
|
||||
return ", ".join(parts) if parts else "empty selector"
|
||||
|
||||
|
||||
def _validate_device(cfg: Config) -> None:
|
||||
device = cfg.aihub.device
|
||||
try:
|
||||
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if matches:
|
||||
return
|
||||
|
||||
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
|
||||
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _quantize_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
calibration_path: Path,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
from_job: str | None = None,
|
||||
model_s3_uri: str | None = None,
|
||||
onnx_path: str | None = None,
|
||||
) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
specs = _input_specs(cfg)
|
||||
try:
|
||||
source = _resolve_model_source(
|
||||
cfg,
|
||||
config_path,
|
||||
model_id=model_id,
|
||||
previous_model_id=st.get_last_optimized_model_id(),
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
calibration_data = _load_calibration(calibration_path, specs)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
hub_model = (
|
||||
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
||||
if isinstance(source.model, Path)
|
||||
else hub.get_model(source.model)
|
||||
)
|
||||
result = aihub_jobs.submit_quantize_job(
|
||||
hub_model,
|
||||
calibration_data,
|
||||
cfg.aihub.quantize_options,
|
||||
job_name=_job_name(cfg, "quantize"),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub quantize failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
updates: dict[str, Any] = {
|
||||
"last_quantize_job_id": result["job_id"],
|
||||
"last_quantized_model_id": result["model_id"],
|
||||
}
|
||||
if source.model_artifact:
|
||||
updates["last_model_artifact"] = source.model_artifact
|
||||
st.update(**updates)
|
||||
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
|
||||
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
|
||||
return str(result["model_id"])
|
||||
|
||||
|
||||
def _optimize_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
from_job: str | None,
|
||||
model_s3_uri: str | None,
|
||||
onnx_path: str | None,
|
||||
) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
try:
|
||||
source = _resolve_onnx_source(
|
||||
cfg,
|
||||
config_path,
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
hub_model = hub.upload_model(str(source.onnx_path), name=cfg.aihub.model_name)
|
||||
result = aihub_jobs.submit_compile_job(
|
||||
model=hub_model,
|
||||
device=cfg.aihub.device,
|
||||
input_specs=specs,
|
||||
target_runtime="onnx",
|
||||
job_name=_job_name(cfg, "optimize"),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub ONNX optimization failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
st.update(
|
||||
last_model_artifact=source.model_artifact,
|
||||
last_optimize_job_id=result["job_id"],
|
||||
last_optimized_model_id=result["model_id"],
|
||||
)
|
||||
CONSOLE.print(f"[green]✓[/green] ONNX optimization job: [bold]{result['job_id']}[/bold]")
|
||||
CONSOLE.print(f"[green]✓[/green] Optimized ONNX model: [bold]{result['model_id']}[/bold]")
|
||||
return str(result["model_id"])
|
||||
|
||||
|
||||
def _compile_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
from_job: str | None = None,
|
||||
model_s3_uri: str | None = None,
|
||||
onnx_path: str | None = None,
|
||||
) -> str:
|
||||
st = state_ops.store(config_path)
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
try:
|
||||
source = _resolve_model_source(
|
||||
cfg,
|
||||
config_path,
|
||||
model_id=model_id,
|
||||
previous_model_id=st.get_last_quantized_model_id(),
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
hub_model = (
|
||||
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
||||
if isinstance(source.model, Path)
|
||||
else hub.get_model(source.model)
|
||||
)
|
||||
result = aihub_jobs.submit_compile_job(
|
||||
model=hub_model,
|
||||
device=cfg.aihub.device,
|
||||
input_specs=specs,
|
||||
target_runtime=cfg.aihub.target_runtime,
|
||||
options=cfg.aihub.compile_options,
|
||||
job_name=_job_name(cfg, "compile"),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub compile failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
updates: dict[str, Any] = {
|
||||
"last_compile_job_id": result["job_id"],
|
||||
"last_compiled_model_id": result["model_id"],
|
||||
}
|
||||
if source.model_artifact:
|
||||
updates["last_model_artifact"] = source.model_artifact
|
||||
st.update(**updates)
|
||||
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
|
||||
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
|
||||
return str(result["model_id"])
|
||||
|
||||
|
||||
def _validate_step(
|
||||
cfg: Config,
|
||||
config_path: str,
|
||||
input_file: Path,
|
||||
model_id: str | None,
|
||||
input_name: str | None,
|
||||
) -> str:
|
||||
_validate_device(cfg)
|
||||
specs = _input_specs(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
inputs = _load_inputs(input_file, specs, input_name)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
run = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
out_dir = Path(cfg.aihub.output_dir) / run / "validation"
|
||||
try:
|
||||
hub_model = hub.get_model(resolved_model_id)
|
||||
result = aihub_jobs.submit_inference_job(
|
||||
hub_model,
|
||||
cfg.aihub.device,
|
||||
inputs,
|
||||
out_dir,
|
||||
job_name=_job_name(cfg, "validate"),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub inference failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
state_ops.store(config_path).update(last_inference_job_id=result["job_id"])
|
||||
CONSOLE.print(f"[green]✓[/green] Inference job: [bold]{result['job_id']}[/bold]")
|
||||
outputs = result.get("outputs")
|
||||
if isinstance(outputs, dict):
|
||||
for name, value in outputs.items():
|
||||
CONSOLE.print(f" {name}: shape={getattr(value, 'shape', '?')}")
|
||||
CONSOLE.print(f"Outputs: [cyan]{out_dir}[/cyan]")
|
||||
return str(result["job_id"])
|
||||
|
||||
|
||||
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
||||
_validate_device(cfg)
|
||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||
try:
|
||||
hub_model = hub.get_model(resolved_model_id)
|
||||
result = aihub_jobs.submit_profile_job(
|
||||
hub_model,
|
||||
cfg.aihub.device,
|
||||
cfg.aihub.profile_options,
|
||||
job_name=_job_name(cfg, "profile"),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub profile failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
state_ops.store(config_path).update(last_profile_job_id=result["job_id"])
|
||||
CONSOLE.print(f"[green]✓[/green] Profile job: [bold]{result['job_id']}[/bold]")
|
||||
return str(result["job_id"])
|
||||
|
||||
|
||||
@app.command()
|
||||
def optimize(
|
||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should optimize"),
|
||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to optimize"),
|
||||
onnx_path: str | None = typer.Option(
|
||||
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Optimize an external model into a Workbench-produced ONNX model."""
|
||||
cfg = load_cfg(config)
|
||||
_optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def quantize(
|
||||
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub optimized ONNX model ID"),
|
||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should quantize"),
|
||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to quantize"),
|
||||
onnx_path: str | None = typer.Option(
|
||||
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Quantize an ONNX model to INT8."""
|
||||
cfg = load_cfg(config)
|
||||
_quantize_step(
|
||||
cfg,
|
||||
config,
|
||||
calibration_path,
|
||||
model_id=model_id,
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def compile(
|
||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub model ID to compile"),
|
||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should compile"),
|
||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to compile"),
|
||||
onnx_path: str | None = typer.Option(
|
||||
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Compile a model for the configured Qualcomm AI Hub target."""
|
||||
cfg = load_cfg(config)
|
||||
_compile_step(
|
||||
cfg,
|
||||
config,
|
||||
model_id=model_id,
|
||||
from_job=from_job,
|
||||
model_s3_uri=model_s3_uri,
|
||||
onnx_path=onnx_path,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def validate(
|
||||
input_file: Path = typer.Argument(..., help="NumPy .npz or .npy inputs to run on device"),
|
||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy files"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Run an AI Hub inference job using sample inputs."""
|
||||
cfg = load_cfg(config)
|
||||
_validate_step(cfg, config, input_file, model_id, input_name)
|
||||
|
||||
|
||||
@app.command()
|
||||
def profile(
|
||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Profile a compiled model on the configured AI Hub device."""
|
||||
cfg = load_cfg(config)
|
||||
_profile_step(cfg, config, model_id)
|
||||
|
||||
|
||||
@app.command()
|
||||
def upload(
|
||||
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||
input_file: Path = typer.Argument(..., help="Validation .npz or .npy inputs to run on device"),
|
||||
from_step: UploadStep = typer.Option(UploadStep.optimize, "--from-step", help="Resume from this Workbench step"),
|
||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should upload"),
|
||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to upload"),
|
||||
onnx_path: str | None = typer.Option(
|
||||
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
||||
),
|
||||
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy validation files"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Optimize, quantize, optionally compile, validate, and profile a model."""
|
||||
cfg = load_cfg(config)
|
||||
steps = [UploadStep.optimize, UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
|
||||
selected = steps[steps.index(from_step) :]
|
||||
|
||||
optimized_model_id: str | None = None
|
||||
quantized_model_id: str | None = None
|
||||
compiled_model_id: str | None = None
|
||||
if UploadStep.optimize in selected:
|
||||
optimized_model_id = _optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
||||
if UploadStep.quantize in selected:
|
||||
if UploadStep.optimize not in selected:
|
||||
optimized_model_id = state_ops.store(config).get_last_optimized_model_id()
|
||||
if not optimized_model_id:
|
||||
CONSOLE.print(
|
||||
"[red]No optimized ONNX model found. Resume from --from-step optimize or run "
|
||||
"'qc-cli ai-hub optimize' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
quantized_model_id = _quantize_step(
|
||||
cfg,
|
||||
config,
|
||||
calibration_path,
|
||||
model_id=optimized_model_id,
|
||||
)
|
||||
if UploadStep.compile in selected:
|
||||
if cfg.aihub.target_runtime == "onnx":
|
||||
compiled_model_id = quantized_model_id or state_ops.store(config).get_last_quantized_model_id()
|
||||
if not compiled_model_id:
|
||||
CONSOLE.print(
|
||||
"[red]No quantized ONNX model found. Resume from --from-step quantize or run "
|
||||
"'qc-cli ai-hub quantize' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
state_ops.store(config).update(last_compiled_model_id=compiled_model_id)
|
||||
CONSOLE.print("[green]✓[/green] Target runtime is ONNX; skipping final compile.")
|
||||
else:
|
||||
compiled_model_id = _compile_step(
|
||||
cfg,
|
||||
config,
|
||||
model_id=quantized_model_id,
|
||||
)
|
||||
if UploadStep.validate in selected:
|
||||
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
|
||||
if UploadStep.profile in selected:
|
||||
_profile_step(cfg, config, compiled_model_id)
|
||||
|
||||
|
||||
@app.command()
|
||||
def download(
|
||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub compiled model ID"),
|
||||
output: Path | None = typer.Option(None, "--output", "-o", help="Destination file path"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Download the last compiled deployable artifact from AI Hub."""
|
||||
cfg = load_cfg(config)
|
||||
resolved_model_id = _model_id_or_state(config, model_id)
|
||||
ext = _RUNTIME_EXTENSIONS.get(cfg.aihub.target_runtime, cfg.aihub.target_runtime)
|
||||
dest = output or (Path(cfg.aihub.output_dir) / f"model.{ext}")
|
||||
|
||||
try:
|
||||
written = aihub_jobs.download_model(resolved_model_id, dest)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]AI Hub download failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
state_ops.store(config).update(last_downloaded_model=written)
|
||||
CONSOLE.print(f"[green]✓[/green] Downloaded model: [cyan]{written}[/cyan]")
|
||||
@@ -51,6 +51,8 @@ def setup(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||
cloudformation_execution_policy=cloudformation_execution_policy,
|
||||
)
|
||||
with CONSOLE.status("Running cdk deploy..."):
|
||||
@@ -58,6 +60,9 @@ def setup(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
region=cfg.aws.region,
|
||||
stack_name=cfg.infra.stack_name,
|
||||
bootstrap_qualifier=cfg.infra.effective_bootstrap_qualifier,
|
||||
toolkit_stack_name=cfg.infra.effective_toolkit_stack_name,
|
||||
config_path=config,
|
||||
config_dir=str(Path(config).parent),
|
||||
config_snapshot=cfg.model_dump(mode="json"),
|
||||
@@ -72,7 +77,8 @@ def setup(
|
||||
if outputs.get("SageMakerRoleArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] IAM role: {outputs['SageMakerRoleArn']}")
|
||||
if cfg.mlflow.mode is MlflowMode.create and outputs.get("MlflowTrackingServerArn"):
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {outputs['MlflowTrackingServerArn']}")
|
||||
mlflow_name = outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name)
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {mlflow_name}")
|
||||
elif cfg.mlflow.mode is MlflowMode.existing:
|
||||
CONSOLE.print(f"[green]✓[/green] MLflow: {cfg.mlflow.tracking_server_name}")
|
||||
CONSOLE.print("\n[bold green]Infrastructure ready.[/bold green]")
|
||||
@@ -82,7 +88,7 @@ def setup(
|
||||
def status(config: str = CONFIG_OPT) -> None:
|
||||
"""Show current infrastructure status."""
|
||||
cfg = load_cfg(config)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile)
|
||||
stack = cloudformation.stack_status(cfg.aws.region, cfg.aws.profile, cfg.infra.stack_name)
|
||||
|
||||
table = Table(title="Infrastructure Status")
|
||||
table.add_column("Resource", style="cyan")
|
||||
@@ -91,13 +97,13 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
table.add_column("ARN / URI")
|
||||
|
||||
if not stack:
|
||||
table.add_row("CDK Stack", provisioning.STACK_NAME, "[red]missing[/red]", "-")
|
||||
table.add_row("CDK Stack", cfg.infra.stack_name, "[red]missing[/red]", "-")
|
||||
table.add_row("S3 Bucket", cfg.s3.bucket, "[red]unknown[/red]", "-")
|
||||
table.add_row("IAM Role", cfg.sagemaker.role_name, "[red]unknown[/red]", "-")
|
||||
if cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
cfg.effective_mlflow_tracking_server_name or "-",
|
||||
"[red]unknown[/red]",
|
||||
"-",
|
||||
)
|
||||
@@ -114,14 +120,14 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
table.add_row(
|
||||
"IAM Role",
|
||||
cfg.sagemaker.role_name,
|
||||
_role_name(cfg.sagemaker.role_name, outputs.get("SageMakerRoleArn", "")),
|
||||
"[green]managed[/green]",
|
||||
outputs.get("SageMakerRoleArn", "-"),
|
||||
)
|
||||
if cfg.mlflow.mode is MlflowMode.create:
|
||||
table.add_row(
|
||||
"MLflow",
|
||||
cfg.mlflow.tracking_server_name or "-",
|
||||
outputs.get("MlflowTrackingServerName", cfg.managed_mlflow_tracking_server_name),
|
||||
"[green]managed[/green]",
|
||||
outputs.get("MlflowTrackingServerArn", outputs.get("MlflowArtifactUri", "-")),
|
||||
)
|
||||
@@ -156,10 +162,13 @@ def destroy(
|
||||
) -> None:
|
||||
"""Destroy the CDK stack."""
|
||||
cfg = _destroy_config(config)
|
||||
stack_name = _destroy_stack_name(config, cfg)
|
||||
bootstrap_qualifier = _destroy_bootstrap_qualifier(config, cfg)
|
||||
toolkit_stack_name = _destroy_toolkit_stack_name(config, cfg)
|
||||
|
||||
if not yes and not delete_bucket_data:
|
||||
typer.confirm(
|
||||
f"Destroy CDK stack '{provisioning.STACK_NAME}' while retaining S3 bucket data?",
|
||||
f"Destroy CDK stack '{stack_name}' while retaining S3 bucket data?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
@@ -172,13 +181,17 @@ def destroy(
|
||||
provisioning.destroy(
|
||||
profile=cfg.aws.profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=str(snapshot_path),
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {provisioning.STACK_NAME}")
|
||||
CONSOLE.print(f"[green]✓[/green] Destroyed stack: {stack_name}")
|
||||
CONSOLE.print(f"[yellow]CDK bootstrap stack retained: {toolkit_stack_name}[/yellow]")
|
||||
|
||||
|
||||
def _destroy_config(config_path: str) -> Config:
|
||||
@@ -190,6 +203,14 @@ def _destroy_config(config_path: str) -> Config:
|
||||
return load_cfg(config_path)
|
||||
|
||||
|
||||
def _role_name(configured_name: str, role_arn: str) -> str:
|
||||
if configured_name:
|
||||
return configured_name
|
||||
if role_arn:
|
||||
return role_arn.rsplit("/", 1)[-1]
|
||||
return "-"
|
||||
|
||||
|
||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
@@ -197,3 +218,30 @@ def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
if account_id:
|
||||
return str(account_id)
|
||||
return identity.account_id(cfg.aws.region, cfg.aws.profile)
|
||||
|
||||
|
||||
def _destroy_stack_name(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
stack_name = state.get("stack_name")
|
||||
if stack_name:
|
||||
return str(stack_name)
|
||||
return cfg.infra.stack_name
|
||||
|
||||
|
||||
def _destroy_bootstrap_qualifier(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
bootstrap_qualifier = state.get("bootstrap_qualifier")
|
||||
if bootstrap_qualifier:
|
||||
return str(bootstrap_qualifier)
|
||||
return cfg.infra.effective_bootstrap_qualifier
|
||||
|
||||
|
||||
def _destroy_toolkit_stack_name(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
toolkit_stack_name = state.get("toolkit_stack_name")
|
||||
if toolkit_stack_name:
|
||||
return str(toolkit_stack_name)
|
||||
return cfg.infra.effective_toolkit_stack_name
|
||||
|
||||
40
src/commands/init.py
Normal file
40
src/commands/init.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
|
||||
from src.commands.utils import CONSOLE
|
||||
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def init(
|
||||
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
||||
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
||||
) -> None:
|
||||
"""Write a starter config.yaml to the current directory."""
|
||||
dest = Path(output)
|
||||
if dest.exists() and not force:
|
||||
CONSOLE.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
config = _new_isolated_config()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_data = config.model_dump(mode="json")
|
||||
config_data["sagemaker"].pop("role_name", None)
|
||||
with open(dest, "w") as f:
|
||||
yaml.safe_dump(config_data, f, sort_keys=False)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
CONSOLE.print("Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands.")
|
||||
|
||||
|
||||
def _new_isolated_config() -> Config:
|
||||
suffix = secrets.token_hex(6)
|
||||
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
|
||||
config = Config(infra=InfraConfig(stack_name=namespace))
|
||||
config.s3 = S3Config(bucket=f"{namespace}-data")
|
||||
return config
|
||||
95
src/commands/mlflow.py
Normal file
95
src/commands/mlflow.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import webbrowser
|
||||
|
||||
import typer
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import MlflowMode
|
||||
from src.tracking.upload import upload_training_metrics
|
||||
|
||||
app = typer.Typer(help="Manage MLflow tracking server access")
|
||||
|
||||
|
||||
@app.command(name="open")
|
||||
def open_mlflow(config: str = CONFIG_OPT) -> None:
|
||||
"""Open a presigned URL for the configured MLflow tracking server."""
|
||||
cfg = load_cfg(config)
|
||||
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||
if not tracking_server_name:
|
||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
url = aws_mlflow.create_presigned_tracking_server_url(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
tracking_server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
|
||||
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||
CONSOLE.print(f"Reason: {e}")
|
||||
CONSOLE.print(
|
||||
"This command can create presigned URLs only for MLflow tracking servers managed by "
|
||||
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||
CONSOLE.print(f"MLflow UI: {url}")
|
||||
if webbrowser.open(url):
|
||||
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
|
||||
else:
|
||||
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
|
||||
|
||||
|
||||
@app.command(name="upload-metrics")
|
||||
def upload_metrics(
|
||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
||||
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Upload a completed training job's metric history to MLflow."""
|
||||
cfg = load_cfg(config)
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
st = state_ops.store(config)
|
||||
if not job_name:
|
||||
job_name = st.get_last_training_job()
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
|
||||
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
|
||||
return
|
||||
|
||||
try:
|
||||
result = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
force=force,
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if result.metrics_history_uploaded:
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
|
||||
else:
|
||||
CONSOLE.print(
|
||||
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
|
||||
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
|
||||
)
|
||||
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
|
||||
if result.registered_model_version:
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.table import Table
|
||||
@@ -7,7 +9,10 @@ from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra.state import read_infra_state
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
from src.tracking.upload import upload_training_metrics
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
@@ -18,6 +23,8 @@ _STATUS_COLOR = {
|
||||
"Stopping": "yellow",
|
||||
"Stopped": "dim",
|
||||
}
|
||||
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
|
||||
DEFAULT_POLL_INTERVAL_SECONDS = 30
|
||||
|
||||
|
||||
def _tracker(cfg):
|
||||
@@ -28,11 +35,117 @@ def _tracker(cfg):
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
|
||||
def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
||||
state = read_infra_state(_config_dir(config_path))
|
||||
role_arn = state.get("outputs", {}).get("SageMakerRoleArn")
|
||||
if role_arn:
|
||||
return str(role_arn)
|
||||
if cfg.sagemaker.role_name:
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if role_arn:
|
||||
return role_arn
|
||||
raise RuntimeError(f"IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.")
|
||||
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||
|
||||
|
||||
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
|
||||
def _wait_and_upload_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
poll_interval: int,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
) -> None:
|
||||
st = state_ops.store(config_path)
|
||||
previous_status: str | None = None
|
||||
try:
|
||||
while True:
|
||||
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if training_status.status != previous_status:
|
||||
color = _STATUS_COLOR.get(training_status.status, "white")
|
||||
CONSOLE.print(
|
||||
f"Job [cyan]{training_status.name}[/cyan]: "
|
||||
f"[{color}]{training_status.status}[/{color}]"
|
||||
)
|
||||
previous_status = training_status.status
|
||||
if training_status.status in _TERMINAL_STATUSES:
|
||||
_print_training_status(training_status)
|
||||
if training_status.status != "Completed":
|
||||
raise typer.Exit(1)
|
||||
try:
|
||||
result = upload_training_metrics(
|
||||
job_name=job_name,
|
||||
config_path=config_path,
|
||||
cfg=cfg,
|
||||
)
|
||||
if result.metrics_history_uploaded:
|
||||
CONSOLE.print(
|
||||
f"[green]✓[/green] Uploaded training metrics to MLflow run "
|
||||
f"[cyan]{result.run_id}[/cyan]."
|
||||
)
|
||||
else:
|
||||
CONSOLE.print(
|
||||
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
|
||||
"Uploaded SageMaker final metrics only.[/yellow]"
|
||||
)
|
||||
if result.registered_model_version:
|
||||
CONSOLE.print(
|
||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
||||
"([cyan]experiment-latest[/cyan])"
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
||||
CONSOLE.print(
|
||||
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
return
|
||||
time.sleep(poll_interval)
|
||||
except KeyboardInterrupt:
|
||||
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
||||
raise typer.Exit(130)
|
||||
|
||||
|
||||
@app.command()
|
||||
def start(config: str = CONFIG_OPT) -> None:
|
||||
def start(
|
||||
upload_metrics: bool = typer.Option(
|
||||
False,
|
||||
"--upload-metrics",
|
||||
help="Wait for completion, then upload training metrics to MLflow",
|
||||
),
|
||||
poll_interval: int = typer.Option(
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
"--poll-interval",
|
||||
min=1,
|
||||
help="Seconds between status checks when --upload-metrics is used",
|
||||
),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Submit a SageMaker training job."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
|
||||
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not cfg.sagemaker.training.image_uri:
|
||||
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||
CONSOLE.print(
|
||||
@@ -41,9 +154,10 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
role_arn = iam.get_role_arn(cfg.aws.profile, cfg.sagemaker.role_name)
|
||||
if not role_arn:
|
||||
CONSOLE.print(f"[red]IAM role '{cfg.sagemaker.role_name}' not found. Run 'qc-cli infra setup' first.[/red]")
|
||||
try:
|
||||
role_arn = _sagemaker_role_arn(config, cfg)
|
||||
except RuntimeError as e:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
tracker = _tracker(cfg)
|
||||
@@ -68,19 +182,36 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
|
||||
st = state_ops.store(config)
|
||||
st.set_last_training_job(job_name)
|
||||
run_id = tracker.start_training_run(
|
||||
training_job,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
try:
|
||||
run_id = tracker.start_training_run(
|
||||
training_job,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
except Exception as e:
|
||||
run_id = None
|
||||
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
|
||||
CONSOLE.print(
|
||||
"The SageMaker job is still running. Upload metrics after completion with "
|
||||
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
||||
)
|
||||
if run_id:
|
||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
||||
if upload_metrics:
|
||||
_wait_and_upload_metrics(
|
||||
job_name=job_name,
|
||||
poll_interval=poll_interval,
|
||||
config_path=config,
|
||||
cfg=cfg,
|
||||
)
|
||||
else:
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -101,32 +232,7 @@ def status(
|
||||
raise typer.Exit(1)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
color = _STATUS_COLOR.get(status.status, "white")
|
||||
|
||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||
if status.created:
|
||||
CONSOLE.print(f"Created: {status.created}")
|
||||
if status.model_artifacts:
|
||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
job_state = st.get_training_job(job_name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
already_registered = job_state.get("registered_model_version")
|
||||
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||
version = _tracker(cfg).finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if version:
|
||||
updates["registered_model_version"] = version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if version:
|
||||
st.set_latest_prerelease_model_version(version)
|
||||
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]prerelease-latest[/cyan])")
|
||||
_print_training_status(status)
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
|
||||
70
src/commands/upload.py
Normal file
70
src/commands/upload.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||
|
||||
from src.aws import s3 as s3_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def upload(
|
||||
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
||||
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Upload a local file or directory to S3."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if path.is_file():
|
||||
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
|
||||
try:
|
||||
with CONSOLE.status(f"Uploading {path.name}..."):
|
||||
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] {path.name} -> {uri}")
|
||||
return
|
||||
|
||||
if path.is_dir():
|
||||
if s3_key is not None:
|
||||
CONSOLE.print("[red]--s3-key can only be used when uploading a single file.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
files = [file for file in path.rglob("*") if file.is_file()]
|
||||
if not files:
|
||||
CONSOLE.print("[yellow]No files found in directory.[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
prefix = cfg.s3.data_prefix
|
||||
CONSOLE.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||
try:
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=CONSOLE,
|
||||
) as progress:
|
||||
task = progress.add_task("Uploading...", total=len(files))
|
||||
count = s3_ops.upload_dir(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.s3.bucket,
|
||||
str(path),
|
||||
prefix,
|
||||
on_progress=lambda: progress.advance(task),
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||
return
|
||||
|
||||
CONSOLE.print(f"[red]Path not found: {path}[/red]")
|
||||
raise typer.Exit(1)
|
||||
@@ -1,18 +1,20 @@
|
||||
from enum import Enum
|
||||
import re
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from qai_hub.client import Device
|
||||
|
||||
|
||||
class MlflowMode(str, Enum):
|
||||
class MlflowMode(StrEnum):
|
||||
disabled = "disabled"
|
||||
create = "create"
|
||||
existing = "existing"
|
||||
|
||||
|
||||
class MlflowServerSize(str, Enum):
|
||||
class MlflowServerSize(StrEnum):
|
||||
small = "Small"
|
||||
medium = "Medium"
|
||||
large = "Large"
|
||||
@@ -32,6 +34,33 @@ class AwsConfig(BaseModel):
|
||||
return {"profile_name": self.profile, "region_name": self.region}
|
||||
|
||||
|
||||
DEFAULT_BOOTSTRAP_QUALIFIER = "hnb659fds"
|
||||
GENERATED_STACK_PREFIX = "qc-cli-mlops-"
|
||||
|
||||
|
||||
class InfraConfig(BaseModel):
|
||||
stack_name: str
|
||||
|
||||
@property
|
||||
def effective_bootstrap_qualifier(self) -> str:
|
||||
sanitized = re.sub(r"[^a-z0-9]", "", self.stack_name.lower())
|
||||
if not sanitized:
|
||||
return DEFAULT_BOOTSTRAP_QUALIFIER
|
||||
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
|
||||
suffix = re.sub(r"[^a-z0-9]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX).lower())
|
||||
if suffix:
|
||||
return f"q{suffix}"[:10]
|
||||
return f"q{sanitized}"[:10]
|
||||
|
||||
@property
|
||||
def effective_toolkit_stack_name(self) -> str:
|
||||
if self.stack_name.startswith(GENERATED_STACK_PREFIX):
|
||||
suffix = re.sub(r"[^A-Za-z0-9-]", "", self.stack_name.removeprefix(GENERATED_STACK_PREFIX))
|
||||
if suffix:
|
||||
return f"{self.stack_name}-bootstrap"
|
||||
return f"{self.stack_name}-bootstrap"
|
||||
|
||||
|
||||
class S3Config(BaseModel):
|
||||
bucket: str = "my-qc-mlops-bucket"
|
||||
data_prefix: str = "data/"
|
||||
@@ -48,10 +77,29 @@ class TrainingConfig(BaseModel):
|
||||
|
||||
|
||||
class SageMakerConfig(BaseModel):
|
||||
role_name: str = "qc-cli-sagemaker-role"
|
||||
role_name: str = ""
|
||||
training: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
|
||||
|
||||
class AIHubConfig(BaseModel):
|
||||
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
|
||||
target_runtime: str = "tflite"
|
||||
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
|
||||
job_name: str | None = None
|
||||
model_name: str | None = None
|
||||
compile_options: str | None = None
|
||||
profile_options: str | None = None
|
||||
quantize_options: str | None = None
|
||||
output_dir: str = "build/qai-hub"
|
||||
|
||||
@field_validator("device", mode="before")
|
||||
@classmethod
|
||||
def parse_device(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return Device(value)
|
||||
return value
|
||||
|
||||
|
||||
class MlflowConfig(BaseModel):
|
||||
mode: MlflowMode = MlflowMode.disabled
|
||||
tracking_server_name: str | None = None
|
||||
@@ -66,13 +114,27 @@ class MlflowConfig(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_tracking_server_name(self) -> "MlflowConfig":
|
||||
if self.mode in {MlflowMode.create, MlflowMode.existing} and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is create or existing")
|
||||
if self.mode is MlflowMode.existing and not self.tracking_server_name:
|
||||
raise ValueError("mlflow.tracking_server_name is required when mlflow.mode is existing")
|
||||
return self
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
infra: InfraConfig
|
||||
aws: AwsConfig = Field(default_factory=AwsConfig)
|
||||
s3: S3Config = Field(default_factory=S3Config)
|
||||
sagemaker: SageMakerConfig = Field(default_factory=SageMakerConfig)
|
||||
aihub: AIHubConfig = Field(default_factory=AIHubConfig)
|
||||
mlflow: MlflowConfig = Field(default_factory=MlflowConfig)
|
||||
|
||||
@property
|
||||
def managed_mlflow_tracking_server_name(self) -> str:
|
||||
return f"{self.infra.stack_name}-mlflow"
|
||||
|
||||
@property
|
||||
def effective_mlflow_tracking_server_name(self) -> str | None:
|
||||
if self.mlflow.mode is MlflowMode.disabled:
|
||||
return None
|
||||
if self.mlflow.mode is MlflowMode.existing:
|
||||
return self.mlflow.tracking_server_name
|
||||
return self.managed_mlflow_tracking_server_name
|
||||
|
||||
@@ -5,17 +5,27 @@ from typing import Any
|
||||
|
||||
from src.infra.state import state_path, write_infra_state
|
||||
|
||||
STACK_NAME = "MLOpsStack"
|
||||
|
||||
|
||||
def bootstrap(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
cloudformation_execution_policy: str | None = None,
|
||||
) -> None:
|
||||
cmd = ["cdk", "bootstrap", f"aws://{account_id}/{region}", "--profile", profile]
|
||||
cmd = [
|
||||
"cdk",
|
||||
"bootstrap",
|
||||
f"aws://{account_id}/{region}",
|
||||
"--profile",
|
||||
profile,
|
||||
"--qualifier",
|
||||
bootstrap_qualifier,
|
||||
"--toolkit-stack-name",
|
||||
toolkit_stack_name,
|
||||
]
|
||||
if cloudformation_execution_policy:
|
||||
cmd.extend(["--cloudformation-execution-policies", cloudformation_execution_policy])
|
||||
_run(cmd)
|
||||
@@ -26,6 +36,9 @@ def deploy(
|
||||
profile: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
config_dir: str,
|
||||
config_snapshot: dict[str, Any],
|
||||
@@ -35,19 +48,24 @@ def deploy(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=False,
|
||||
) + ["--require-approval", "never", "--outputs-file", str(outputs_file)]
|
||||
_run(cmd)
|
||||
|
||||
outputs = _read_outputs(outputs_file)
|
||||
outputs = _read_outputs(outputs_file, stack_name)
|
||||
state = {
|
||||
"stack_name": STACK_NAME,
|
||||
"stack_name": stack_name,
|
||||
"aws": {
|
||||
"account_id": account_id,
|
||||
"region": region,
|
||||
"profile": profile,
|
||||
},
|
||||
"bootstrap_qualifier": bootstrap_qualifier,
|
||||
"toolkit_stack_name": toolkit_stack_name,
|
||||
"config": config_snapshot,
|
||||
"outputs": outputs,
|
||||
}
|
||||
@@ -59,6 +77,9 @@ def destroy(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> None:
|
||||
@@ -67,6 +88,9 @@ def destroy(
|
||||
"deploy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=True,
|
||||
) + ["--require-approval", "never"]
|
||||
@@ -76,6 +100,9 @@ def destroy(
|
||||
"destroy",
|
||||
profile=profile,
|
||||
account_id=account_id,
|
||||
stack_name=stack_name,
|
||||
bootstrap_qualifier=bootstrap_qualifier,
|
||||
toolkit_stack_name=toolkit_stack_name,
|
||||
config_path=config_path,
|
||||
delete_bucket_data=delete_bucket_data,
|
||||
) + ["--force"]
|
||||
@@ -87,26 +114,35 @@ def _cdk_cmd(
|
||||
*,
|
||||
profile: str,
|
||||
account_id: str,
|
||||
stack_name: str,
|
||||
bootstrap_qualifier: str,
|
||||
toolkit_stack_name: str,
|
||||
config_path: str,
|
||||
delete_bucket_data: bool,
|
||||
) -> list[str]:
|
||||
cmd = [
|
||||
"cdk",
|
||||
action,
|
||||
STACK_NAME,
|
||||
stack_name,
|
||||
"--app",
|
||||
"python app.py",
|
||||
"--profile",
|
||||
profile,
|
||||
]
|
||||
if action == "deploy":
|
||||
cmd.extend(["--toolkit-stack-name", toolkit_stack_name])
|
||||
cmd.extend([
|
||||
"-c",
|
||||
f"account_id={account_id}",
|
||||
"-c",
|
||||
f"config={config_path}",
|
||||
"-c",
|
||||
f"stack_name={STACK_NAME}",
|
||||
f"stack_name={stack_name}",
|
||||
"-c",
|
||||
f"bootstrap_qualifier={bootstrap_qualifier}",
|
||||
"-c",
|
||||
f"delete_bucket_data={str(delete_bucket_data).lower()}",
|
||||
]
|
||||
])
|
||||
return cmd
|
||||
|
||||
|
||||
@@ -119,9 +155,9 @@ def _run(cmd: list[str]) -> None:
|
||||
raise RuntimeError(f"CDK command failed with exit code {e.returncode}.") from e
|
||||
|
||||
|
||||
def _read_outputs(path: Path) -> dict[str, str]:
|
||||
def _read_outputs(path: Path, stack_name: str) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
return data.get(STACK_NAME, {})
|
||||
return data.get(stack_name, {})
|
||||
|
||||
@@ -34,7 +34,7 @@ class QCStack(Stack):
|
||||
role = iam.CfnRole(
|
||||
self,
|
||||
"SageMakerRole",
|
||||
role_name=config.sagemaker.role_name,
|
||||
role_name=config.sagemaker.role_name or None,
|
||||
assume_role_policy_document=self._sagemaker_trust_policy(),
|
||||
managed_policy_arns=[
|
||||
f"arn:{self.partition}:iam::aws:policy/AmazonSageMakerFullAccess",
|
||||
@@ -74,6 +74,7 @@ class QCStack(Stack):
|
||||
CfnOutput(self, "SageMakerRoleArn", value=role.attr_arn)
|
||||
|
||||
if config.mlflow.mode is MlflowMode.create:
|
||||
tracking_server_name = config.managed_mlflow_tracking_server_name
|
||||
artifact_prefix = config.mlflow.artifact_prefix.strip("/")
|
||||
artifact_uri = (
|
||||
f"s3://{data_bucket.bucket_name}/{artifact_prefix}/"
|
||||
@@ -145,14 +146,14 @@ class QCStack(Stack):
|
||||
"MlflowTrackingServer",
|
||||
artifact_store_uri=artifact_uri,
|
||||
role_arn=mlflow_role.attr_arn,
|
||||
tracking_server_name=config.mlflow.tracking_server_name or "",
|
||||
tracking_server_name=tracking_server_name,
|
||||
automatic_model_registration=config.mlflow.automatic_model_registration,
|
||||
mlflow_version=config.mlflow.mlflow_version,
|
||||
tracking_server_size=config.mlflow.tracking_server_size.value,
|
||||
weekly_maintenance_window_start=config.mlflow.weekly_maintenance_window_start,
|
||||
)
|
||||
|
||||
CfnOutput(self, "MlflowTrackingServerName", value=config.mlflow.tracking_server_name or "")
|
||||
CfnOutput(self, "MlflowTrackingServerName", value=tracking_server_name)
|
||||
CfnOutput(self, "MlflowTrackingServerArn", value=tracking_server.attr_tracking_server_arn)
|
||||
CfnOutput(self, "MlflowArtifactUri", value=artifact_uri)
|
||||
CfnOutput(self, "MlflowRoleArn", value=mlflow_role.attr_arn)
|
||||
|
||||
100
src/main.py
100
src/main.py
@@ -1,104 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||
|
||||
from src.aws import s3 as s3_ops
|
||||
from src.commands import infra, train
|
||||
from src.commands.utils import CONFIG_OPT, load_cfg
|
||||
from src.config import Config
|
||||
from src.commands import ai_hub, infra, init, mlflow, train, upload
|
||||
|
||||
app = typer.Typer(
|
||||
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(init.app)
|
||||
app.add_typer(upload.app)
|
||||
app.add_typer(mlflow.app, name="mlflow")
|
||||
app.add_typer(infra.app, name="infra")
|
||||
app.add_typer(train.app, name="train")
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
|
||||
def init(
|
||||
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
||||
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
||||
) -> None:
|
||||
"""Write a starter config.yaml to the current directory."""
|
||||
dest = Path(output)
|
||||
if dest.exists() and not force:
|
||||
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
config = Config()
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(dest, "w") as f:
|
||||
yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
|
||||
|
||||
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||
console.print(
|
||||
"Edit it (especially [cyan]s3.bucket[/cyan] and [cyan]sagemaker.training.image_uri[/cyan]) "
|
||||
"before running other commands."
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def upload(
|
||||
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
||||
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
|
||||
config: str = CONFIG_OPT,
|
||||
) -> None:
|
||||
"""Upload a local file or directory to S3."""
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if path.is_file():
|
||||
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
|
||||
try:
|
||||
with console.status(f"Uploading {path.name}..."):
|
||||
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
|
||||
except Exception as e:
|
||||
console.print(f"[red]Upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[green]✓[/green] {path.name} -> {uri}")
|
||||
return
|
||||
|
||||
if path.is_dir():
|
||||
if s3_key is not None:
|
||||
console.print("[red]--s3-key can only be used when uploading a single file.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
files = [file for file in path.rglob("*") if file.is_file()]
|
||||
if not files:
|
||||
console.print("[yellow]No files found in directory.[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
prefix = cfg.s3.data_prefix
|
||||
console.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||
try:
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("Uploading...", total=len(files))
|
||||
count = s3_ops.upload_dir(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.s3.bucket,
|
||||
str(path),
|
||||
prefix,
|
||||
on_progress=lambda: progress.advance(task),
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"[red]Upload failed: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||
return
|
||||
|
||||
console.print(f"[red]Path not found: {path}[/red]")
|
||||
raise typer.Exit(1)
|
||||
app.add_typer(ai_hub.app, name="ai-hub")
|
||||
|
||||
0
src/qualcomm/__init__.py
Normal file
0
src/qualcomm/__init__.py
Normal file
114
src/qualcomm/aihub_jobs.py
Normal file
114
src/qualcomm/aihub_jobs.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import qai_hub.hub as hub
|
||||
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
|
||||
|
||||
|
||||
class ModelJobResult(TypedDict):
|
||||
job: CompileJob | QuantizeJob
|
||||
job_id: str
|
||||
model: Model
|
||||
model_id: str
|
||||
|
||||
|
||||
class InferenceJobResult(TypedDict):
|
||||
job: InferenceJob
|
||||
job_id: str
|
||||
outputs: Any
|
||||
|
||||
|
||||
class ProfileJobResult(TypedDict):
|
||||
job: ProfileJob
|
||||
job_id: str
|
||||
|
||||
|
||||
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
||||
return {name: value if isinstance(value, list) else [value] for name, value in inputs.items()}
|
||||
|
||||
|
||||
def submit_compile_job(
|
||||
model: Model,
|
||||
device: Device,
|
||||
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
||||
target_runtime: str,
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> ModelJobResult:
|
||||
compile_options = f"--target_runtime {target_runtime}"
|
||||
if options:
|
||||
compile_options = f"{compile_options} {options}"
|
||||
|
||||
job = hub.submit_compile_job(
|
||||
model=model,
|
||||
device=device,
|
||||
name=job_name,
|
||||
input_specs=input_specs,
|
||||
options=compile_options,
|
||||
)
|
||||
target_model = job.get_target_model()
|
||||
if target_model is None:
|
||||
raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.")
|
||||
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||
|
||||
|
||||
def submit_inference_job(
|
||||
model: Model,
|
||||
device: Device,
|
||||
inputs: dict[str, Any],
|
||||
output_dir: str | Path,
|
||||
job_name: str | None = None,
|
||||
) -> InferenceJobResult:
|
||||
job = hub.submit_inference_job(
|
||||
model=model,
|
||||
device=device,
|
||||
inputs=_dataset_entries(inputs),
|
||||
name=job_name,
|
||||
)
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
data = job.download_output_data(str(out))
|
||||
return {"job": job, "job_id": str(job.job_id), "outputs": data}
|
||||
|
||||
|
||||
def submit_profile_job(
|
||||
model: Model,
|
||||
device: Device,
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> ProfileJobResult:
|
||||
job = hub.submit_profile_job(
|
||||
model=model,
|
||||
device=device,
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
return {"job": job, "job_id": str(job.job_id)}
|
||||
|
||||
|
||||
def submit_quantize_job(
|
||||
model: Model,
|
||||
calibration_data: dict[str, Any],
|
||||
options: str | None = None,
|
||||
job_name: str | None = None,
|
||||
) -> ModelJobResult:
|
||||
job = hub.submit_quantize_job(
|
||||
model=model,
|
||||
calibration_data=_dataset_entries(calibration_data),
|
||||
weights_dtype=QuantizeDtype.INT8,
|
||||
activations_dtype=QuantizeDtype.INT8,
|
||||
name=job_name,
|
||||
options=options or "",
|
||||
)
|
||||
target_model = job.get_target_model()
|
||||
if target_model is None:
|
||||
raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.")
|
||||
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)}
|
||||
|
||||
|
||||
def download_model(model_id: str, output_path: str | Path) -> str:
|
||||
dest = Path(output_path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
model = hub.get_model(model_id)
|
||||
result = model.download(str(dest))
|
||||
return str(result or dest)
|
||||
83
src/qualcomm/artifacts.py
Normal file
83
src/qualcomm/artifacts.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import tarfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from src.aws import s3 as s3_ops
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.config import Config
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedOnnx:
|
||||
onnx_path: Path
|
||||
model_artifact: str | None
|
||||
run_name: str
|
||||
|
||||
|
||||
def _safe_extract(tar: tarfile.TarFile, dest: Path) -> None:
|
||||
dest_root = dest.resolve()
|
||||
for member in tar.getmembers():
|
||||
target = (dest / member.name).resolve()
|
||||
if dest_root != target and dest_root not in target.parents:
|
||||
raise ValueError(f"Unsafe tar member path: {member.name}")
|
||||
tar.extractall(dest, filter="data")
|
||||
|
||||
|
||||
def _find_onnx(root: Path, explicit: str | None = None) -> Path:
|
||||
if explicit:
|
||||
p = Path(explicit)
|
||||
if not p.is_absolute():
|
||||
p = root / p
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"ONNX file not found: {p}")
|
||||
return p
|
||||
|
||||
matches = sorted(root.rglob("model.onnx"))
|
||||
if not matches:
|
||||
matches = sorted(root.rglob("*.onnx"))
|
||||
if not matches:
|
||||
raise FileNotFoundError(f"No ONNX file found under {root}")
|
||||
if len(matches) > 1:
|
||||
joined = ", ".join(str(p.relative_to(root)) for p in matches)
|
||||
raise ValueError(f"Multiple ONNX files found ({joined}). Pass --onnx-path.")
|
||||
return matches[0]
|
||||
|
||||
|
||||
def resolve_onnx(
|
||||
cfg: Config,
|
||||
output_dir: str,
|
||||
from_job: str | None = None,
|
||||
model_s3_uri: str | None = None,
|
||||
onnx_path: str | None = None,
|
||||
last_training_job: str | None = None,
|
||||
) -> ResolvedOnnx:
|
||||
if onnx_path:
|
||||
path = Path(onnx_path)
|
||||
if path.exists():
|
||||
return ResolvedOnnx(onnx_path=path, model_artifact=None, run_name=path.stem)
|
||||
|
||||
job = from_job or last_training_job
|
||||
artifact = model_s3_uri
|
||||
if not artifact:
|
||||
if not job:
|
||||
raise ValueError("No model source found. Pass --onnx-path, --model-s3-uri, --from-job, or run training first.")
|
||||
artifact = sm_ops.get_model_artifacts(cfg.aws.region, cfg.aws.profile, job)
|
||||
|
||||
run_name = job or Path(artifact).name.removesuffix(".tar.gz").replace("/", "-")
|
||||
root = Path(output_dir) / run_name / "source"
|
||||
tar_path = root / "model.tar.gz"
|
||||
s3_ops.download_file(cfg.aws.region, cfg.aws.profile, artifact, str(tar_path))
|
||||
|
||||
extract_dir = root / "extracted"
|
||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
with tarfile.open(tar_path, "r:gz") as tar:
|
||||
_safe_extract(tar, extract_dir)
|
||||
except tarfile.TarError as e:
|
||||
raise ValueError(f"Invalid model tarball: {tar_path}") from e
|
||||
|
||||
return ResolvedOnnx(
|
||||
onnx_path=_find_onnx(extract_dir, onnx_path),
|
||||
model_artifact=artifact,
|
||||
run_name=run_name,
|
||||
)
|
||||
24
src/state.py
24
src/state.py
@@ -33,6 +33,26 @@ class CliStateStore:
|
||||
value = self.get("last_training_job")
|
||||
return str(value) if value else None
|
||||
|
||||
def get_last_model_artifact(self) -> str | None:
|
||||
value = self.get("last_model_artifact")
|
||||
return str(value) if value else None
|
||||
|
||||
def get_last_optimized_model_id(self) -> str | None:
|
||||
value = self.get("last_optimized_model_id")
|
||||
return str(value) if value else None
|
||||
|
||||
def get_last_quantized_model_id(self) -> str | None:
|
||||
value = self.get("last_quantized_model_id")
|
||||
return str(value) if value else None
|
||||
|
||||
def get_last_compiled_model_id(self) -> str | None:
|
||||
value = self.get("last_compiled_model_id")
|
||||
return str(value) if value else None
|
||||
|
||||
def get_last_downloaded_model(self) -> str | None:
|
||||
value = self.get("last_downloaded_model")
|
||||
return str(value) if value else None
|
||||
|
||||
def set_last_training_job(self, job_name: str) -> None:
|
||||
self.update(last_training_job=job_name)
|
||||
|
||||
@@ -48,8 +68,8 @@ class CliStateStore:
|
||||
state["training_jobs"] = jobs
|
||||
self._write(state)
|
||||
|
||||
def set_latest_prerelease_model_version(self, version: str) -> None:
|
||||
self.update(latest_prerelease_model_version=version)
|
||||
def set_latest_experiment_model_version(self, version: str) -> None:
|
||||
self.update(latest_experiment_model_version=version)
|
||||
|
||||
def _write(self, state: dict[str, Any]) -> None:
|
||||
with open(self.path, "w") as f:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
|
||||
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
|
||||
|
||||
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
|
||||
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]
|
||||
|
||||
93
src/tracking/metrics.py
Normal file
93
src/tracking/metrics.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
import math
|
||||
import tarfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
METRICS_ARTIFACT_NAME = "training_metrics.json"
|
||||
METRICS_SCHEMA_VERSION = 1
|
||||
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricStep:
|
||||
step: int
|
||||
metrics: dict[str, float]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainingMetrics:
|
||||
steps: list[MetricStep]
|
||||
summary: dict[str, float]
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
def parse_training_metrics(data: bytes) -> TrainingMetrics:
|
||||
try:
|
||||
value = json.loads(data)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
|
||||
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
|
||||
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
|
||||
|
||||
raw_steps = value.get("steps")
|
||||
if not isinstance(raw_steps, list):
|
||||
raise ValueError("training metrics 'steps' must be a list")
|
||||
|
||||
steps: list[MetricStep] = []
|
||||
previous_step: int | None = None
|
||||
for index, raw_step in enumerate(raw_steps):
|
||||
if not isinstance(raw_step, dict):
|
||||
raise ValueError(f"training metrics step {index} must be an object")
|
||||
step = raw_step.get("step")
|
||||
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
|
||||
raise ValueError(f"training metrics step {index} has an invalid 'step'")
|
||||
if previous_step is not None and step <= previous_step:
|
||||
raise ValueError("training metrics steps must be unique and strictly increasing")
|
||||
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
|
||||
steps.append(MetricStep(step=step, metrics=metrics))
|
||||
previous_step = step
|
||||
|
||||
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
|
||||
return TrainingMetrics(steps=steps, summary=summary, raw=value)
|
||||
|
||||
|
||||
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
|
||||
with tarfile.open(archive_path, mode="r:*") as archive:
|
||||
matches = [
|
||||
member
|
||||
for member in archive.getmembers()
|
||||
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
if len(matches) > 1:
|
||||
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
|
||||
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
|
||||
raise ValueError(
|
||||
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
|
||||
)
|
||||
extracted = archive.extractfile(matches[0])
|
||||
if extracted is None:
|
||||
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
|
||||
return extracted.read()
|
||||
|
||||
|
||||
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{context} 'metrics' must be an object")
|
||||
|
||||
metrics: dict[str, float] = {}
|
||||
for raw_name, raw_value in value.items():
|
||||
if not isinstance(raw_name, str) or not raw_name:
|
||||
raise ValueError(f"{context} contains an invalid metric name")
|
||||
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
|
||||
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
|
||||
metric_value = float(raw_value)
|
||||
if not math.isfinite(metric_value):
|
||||
raise ValueError(f"{context} metric '{raw_name}' must be finite")
|
||||
metrics[raw_name] = metric_value
|
||||
return metrics
|
||||
@@ -1,10 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from src.aws import s3
|
||||
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
|
||||
from src.config import Config, MlflowMode
|
||||
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinalizeResult:
|
||||
registered_model_version: str | None = None
|
||||
warnings: tuple[str, ...] = ()
|
||||
|
||||
|
||||
class Tracker(Protocol):
|
||||
@@ -15,7 +26,21 @@ class Tracker(Protocol):
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
) -> str | None: ...
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult: ...
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str: ...
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -23,124 +48,187 @@ class NoopTracker:
|
||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||
return None
|
||||
|
||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||
return None
|
||||
def finalize_training_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult:
|
||||
return FinalizeResult()
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str:
|
||||
raise RuntimeError("MLflow is disabled.")
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> bool:
|
||||
raise RuntimeError("MLflow is disabled.")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MlflowTracker:
|
||||
mlflow: Any
|
||||
tracking_uri: str
|
||||
experiment_name: str
|
||||
registered_model_name: str
|
||||
register_trained_models: bool
|
||||
tracking_backend: MlflowTrackingBackend
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Config) -> Tracker:
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
return NoopTracker()
|
||||
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLflow is enabled in config but optional dependencies are not installed. "
|
||||
"Install with: qc-cli[mlflow]"
|
||||
) from e
|
||||
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
|
||||
|
||||
if not cfg.mlflow.tracking_server_name:
|
||||
raise RuntimeError("mlflow.tracking_server_name is required when MLflow is enabled.")
|
||||
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||
if not tracking_server_name:
|
||||
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
||||
|
||||
tracking_uri = aws_mlflow.get_tracking_server_arn(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
cfg.mlflow.tracking_server_name,
|
||||
)
|
||||
mlflow.set_tracking_uri(tracking_uri)
|
||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||
tracking_backend = mlflow_tracking_backend_from_config(cfg)
|
||||
|
||||
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
|
||||
with tracking_backend.auth_env():
|
||||
mlflow.set_tracking_uri(tracking_uri)
|
||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||
|
||||
return cls(
|
||||
mlflow=mlflow,
|
||||
tracking_uri=tracking_uri,
|
||||
experiment_name=cfg.mlflow.experiment_name,
|
||||
registered_model_name=cfg.mlflow.registered_model_name,
|
||||
register_trained_models=cfg.mlflow.register_trained_models,
|
||||
tracking_backend=tracking_backend,
|
||||
)
|
||||
|
||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||
run = self.mlflow.start_run(run_name=training_job.job_name)
|
||||
run_id = str(run.info.run_id)
|
||||
with self.tracking_backend.auth_env():
|
||||
with mlflow.start_run(run_name=training_job.job_name) as run:
|
||||
run_id = str(run.info.run_id)
|
||||
self._log_params(
|
||||
self.tracking_backend.training_run_params(
|
||||
training_job,
|
||||
region=region,
|
||||
profile=profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
)
|
||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||
mlflow.set_tags(
|
||||
{
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": self.tracking_backend.provider_name,
|
||||
"qc_cli.command": "train start",
|
||||
**self.tracking_backend.training_run_tags(training_job),
|
||||
}
|
||||
)
|
||||
return run_id
|
||||
|
||||
params = {
|
||||
"aws.region": region,
|
||||
"aws.profile": profile,
|
||||
"sagemaker.role_arn": role_arn,
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
"sagemaker.training_image": training_job.image_uri,
|
||||
"sagemaker.instance_type": training_job.instance_type,
|
||||
"sagemaker.instance_count": training_job.instance_count,
|
||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||
"sagemaker.entry_point": training_job.entry_point,
|
||||
"sagemaker.source_dir": training_job.source_dir,
|
||||
}
|
||||
self._log_params(params)
|
||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||
self.mlflow.set_tags(
|
||||
{
|
||||
"qc_cli.stage": "prerelease",
|
||||
"qc_cli.command": "train start",
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
}
|
||||
)
|
||||
self.mlflow.end_run()
|
||||
return run_id
|
||||
|
||||
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||
def finalize_training_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str | None,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
command: str,
|
||||
) -> FinalizeResult:
|
||||
if not run_id:
|
||||
return None
|
||||
return FinalizeResult()
|
||||
|
||||
with self.mlflow.start_run(run_id=run_id):
|
||||
self._log_params(
|
||||
{
|
||||
"sagemaker.training_status": training_job_status.status,
|
||||
"sagemaker.created_at": training_job_status.created,
|
||||
"sagemaker.modified_at": training_job_status.modified,
|
||||
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||
}
|
||||
)
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
self.mlflow.set_tag("qc_cli.command", "train status")
|
||||
with self.tracking_backend.auth_env():
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
mlflow.set_tag("qc_cli.command", command)
|
||||
|
||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||
self.mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||
return None
|
||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||
return FinalizeResult()
|
||||
|
||||
if not self.register_trained_models:
|
||||
return None
|
||||
if not self.register_trained_models:
|
||||
return FinalizeResult()
|
||||
|
||||
client = self.mlflow.tracking.MlflowClient()
|
||||
self._ensure_registered_model(client, self.registered_model_name)
|
||||
version = client.create_model_version(
|
||||
name=self.registered_model_name,
|
||||
source=training_job_status.model_artifacts,
|
||||
run_id=run_id,
|
||||
client = MlflowClient()
|
||||
self._ensure_registered_model(client, self.registered_model_name)
|
||||
version = client.create_model_version(
|
||||
name=self.registered_model_name,
|
||||
source=training_job_status.model_artifacts,
|
||||
run_id=run_id,
|
||||
tags={
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": self.tracking_backend.provider_name,
|
||||
**self.tracking_backend.model_version_tags(training_job_status),
|
||||
},
|
||||
)
|
||||
version_number = str(version.version)
|
||||
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
||||
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||
return FinalizeResult(registered_model_version=version_number)
|
||||
|
||||
def ensure_training_run(self, job_name: str) -> str:
|
||||
with self.tracking_backend.auth_env():
|
||||
client = MlflowClient()
|
||||
experiment = client.get_experiment_by_name(self.experiment_name)
|
||||
if experiment is None:
|
||||
experiment_id = mlflow.create_experiment(self.experiment_name)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
|
||||
for run in client.search_runs([experiment_id], max_results=1000):
|
||||
if run.data.tags.get("sagemaker.job_name") == job_name:
|
||||
return str(run.info.run_id)
|
||||
|
||||
run = client.create_run(
|
||||
experiment_id,
|
||||
run_name=job_name,
|
||||
tags={
|
||||
"qc_cli.stage": "prerelease",
|
||||
"sagemaker.job_name": training_job_status.name,
|
||||
"qc_cli.stage": "experiment",
|
||||
"qc_cli.artifact_kind": "trained_source",
|
||||
"qc_cli.source": self.tracking_backend.provider_name,
|
||||
"qc_cli.command": "mlflow upload-metrics",
|
||||
"sagemaker.job_name": job_name,
|
||||
},
|
||||
)
|
||||
version_number = str(version.version)
|
||||
self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number)
|
||||
self.mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||
self.mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||
return version_number
|
||||
return str(run.info.run_id)
|
||||
|
||||
def upload_training_metrics(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
training_job_status: Any,
|
||||
region: str,
|
||||
profile: str,
|
||||
) -> bool:
|
||||
if not training_job_status.model_artifacts:
|
||||
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
|
||||
|
||||
with self.tracking_backend.auth_env():
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
||||
self._log_final_metrics(training_job_status.raw)
|
||||
history_uploaded = self._log_training_metrics(
|
||||
training_job_status.model_artifacts,
|
||||
region=region,
|
||||
profile=profile,
|
||||
)
|
||||
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
|
||||
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
|
||||
return history_uploaded
|
||||
|
||||
def _log_params(self, params: dict[str, Any]) -> None:
|
||||
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||
if cleaned:
|
||||
self.mlflow.log_params(cleaned)
|
||||
mlflow.log_params(cleaned)
|
||||
|
||||
def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
|
||||
metrics = {}
|
||||
@@ -150,14 +238,30 @@ class MlflowTracker:
|
||||
if name and value is not None:
|
||||
metrics[str(name)] = float(value)
|
||||
if metrics:
|
||||
self.mlflow.log_metrics(metrics)
|
||||
mlflow.log_metrics(metrics)
|
||||
|
||||
def _ensure_registered_model(self, client: Any, name: str) -> None:
|
||||
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
|
||||
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
||||
archive_path = s3.download_file(
|
||||
region,
|
||||
profile,
|
||||
model_artifacts,
|
||||
os.path.join(temp_dir, "model.tar.gz"),
|
||||
)
|
||||
metrics_data = read_training_metrics_from_tar(archive_path)
|
||||
if metrics_data is None:
|
||||
return False
|
||||
metrics = parse_training_metrics(metrics_data)
|
||||
for metric_step in metrics.steps:
|
||||
if metric_step.metrics:
|
||||
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
||||
if metrics.summary:
|
||||
mlflow.log_metrics(metrics.summary)
|
||||
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
||||
return True
|
||||
|
||||
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||
try:
|
||||
client.get_registered_model(name)
|
||||
except Exception:
|
||||
client.create_registered_model(name)
|
||||
|
||||
def _set_alias(self, client: Any, name: str, alias: str, version: str) -> None:
|
||||
if hasattr(client, "set_registered_model_alias"):
|
||||
client.set_registered_model_alias(name, alias, version)
|
||||
|
||||
75
src/tracking/upload.py
Normal file
75
src/tracking/upload.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src import state as state_ops
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.config import Config, MlflowMode
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricsUploadResult:
|
||||
run_id: str
|
||||
registered_model_version: str | None = None
|
||||
metrics_history_uploaded: bool = True
|
||||
|
||||
|
||||
def upload_training_metrics(
|
||||
*,
|
||||
job_name: str,
|
||||
config_path: str,
|
||||
cfg: Config,
|
||||
force: bool = False,
|
||||
) -> MetricsUploadResult:
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
raise RuntimeError("MLflow is disabled in config.yaml.")
|
||||
|
||||
st = state_ops.store(config_path)
|
||||
job_state = st.get_training_job(job_name)
|
||||
if job_state.get("mlflow_metrics_uploaded") and not force:
|
||||
return MetricsUploadResult(
|
||||
run_id=str(job_state.get("mlflow_run_id") or ""),
|
||||
registered_model_version=(
|
||||
str(job_state["registered_model_version"])
|
||||
if job_state.get("registered_model_version")
|
||||
else None
|
||||
),
|
||||
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
|
||||
)
|
||||
|
||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||
if status.status != "Completed":
|
||||
raise RuntimeError(
|
||||
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
|
||||
)
|
||||
|
||||
tracker = MlflowTracker.from_config(cfg)
|
||||
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
|
||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||
metrics_history_uploaded = tracker.upload_training_metrics(
|
||||
run_id=run_id,
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
)
|
||||
finalized = tracker.finalize_training_run(
|
||||
run_id=run_id,
|
||||
training_job_status=status,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
command="mlflow upload-metrics",
|
||||
)
|
||||
updates = {
|
||||
"mlflow_metrics_uploaded": True,
|
||||
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
|
||||
"mlflow_finalized_status": status.status,
|
||||
}
|
||||
if finalized.registered_model_version:
|
||||
updates["registered_model_version"] = finalized.registered_model_version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if finalized.registered_model_version:
|
||||
st.set_latest_experiment_model_version(finalized.registered_model_version)
|
||||
return MetricsUploadResult(
|
||||
run_id=run_id,
|
||||
registered_model_version=finalized.registered_model_version,
|
||||
metrics_history_uploaded=metrics_history_uploaded,
|
||||
)
|
||||
182
uv.lock
generated
182
uv.lock
generated
@@ -210,6 +210,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/06/7c/1e7964f0f267301bb5026fed45369961f7311073412bcd36e09fbe4df0de/aws_cdk_lib-2.253.1-py3-none-any.whl", hash = "sha256:03a6f5080978f9e3576f490d06fbd1f41f159280d34dbca50721de4a19694136", size = 50271288, upload-time = "2026-05-08T16:04:41.956Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "2.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/d7/5bbeb12c44d7c4f2fb5b56abce497eb5ed9f34d85701de869acedd602619/backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba", size = 17001, upload-time = "2022-10-05T19:19:32.061Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blinker"
|
||||
version = "1.9.0"
|
||||
@@ -591,6 +600,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/43/33806117fc8e0992aae890be73990b31d802b66e8a423bf87b80990fce66/databricks_sdk-0.111.0-py3-none-any.whl", hash = "sha256:d14ba186afd2bea03c7157d2f03e0f861a0b8eff528cfdba926d07b9e20384b8", size = 901536, upload-time = "2026-05-25T09:29:58.057Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "packaging" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/d3/8ae2869247df154b64c1884d7346d412fed0c49df84db635aab2d1c40e62/deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff", size = 173788, upload-time = "2020-04-20T14:23:38.738Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a", size = 11178, upload-time = "2020-04-20T14:23:36.581Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
@@ -908,6 +929,41 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h5py"
|
||||
version = "3.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/db/33/acd0ce6863b6c0d7735007df01815403f5589a21ff8c2e1ee2587a38f548/h5py-3.16.0.tar.gz", hash = "sha256:a0dbaad796840ccaa67a4c144a0d0c8080073c34c76d5a6941d6818678ef2738", size = 446526, upload-time = "2026-03-06T13:49:08.07Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/9e/6142ebfda0cb6e9349c091eae73c2e01a770b7659255248d637bec54a88b/h5py-3.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:370a845f432c2c9619db8eed334d1e610c6015796122b0e57aa46312c22617d9", size = 3671808, upload-time = "2026-03-06T13:48:19.737Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/65/5e088a45d0f43cd814bc5bec521c051d42005a472e804b1a36c48dada09b/h5py-3.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:42108e93326c50c2810025aade9eac9d6827524cdccc7d4b75a546e5ab308edb", size = 3045837, upload-time = "2026-03-06T13:48:21.854Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/da/1e/6172269e18cc5a484e2913ced33339aad588e02ba407fafd00d369e22ef3/h5py-3.16.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:099f2525c9dcf28de366970a5fb34879aab20491589fa89ce2863a84218bb524", size = 5193860, upload-time = "2026-03-06T13:48:24.071Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/98/ef2b6fe2903e377cbe870c3b2800d62552f1e3dbe81ce49e1923c53d1c5c/h5py-3.16.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9300ad32dea9dfc5171f94d5f6948e159ed93e4701280b0f508773b3f582f402", size = 5400417, upload-time = "2026-03-06T13:48:25.728Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/81/5b62d760039eed64348c98129d17061fdfc7839fc9c04eaaad6dee1004e4/h5py-3.16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:171038f23bccddfc23f344cadabdfc9917ff554db6a0d417180d2747fe4c75a7", size = 5185214, upload-time = "2026-03-06T13:48:27.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/c4/532123bcd9080e250696779c927f2cb906c8bf3447df98f5ceb8dcded539/h5py-3.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7e420b539fb6023a259a1b14d4c9f6df8cf50d7268f48e161169987a57b737ff", size = 5414598, upload-time = "2026-03-06T13:48:29.49Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/d9/a27997f84341fc0dfcdd1fe4179b6ba6c32a7aa880fdb8c514d4dad6fba3/h5py-3.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:18f2bbcd545e6991412253b98727374c356d67caa920e68dc79eab36bf5fedad", size = 3175509, upload-time = "2026-03-06T13:48:31.131Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/23/bb8647521d4fd770c30a76cfc6cb6a2f5495868904054e92f2394c5a78ff/h5py-3.16.0-cp313-cp313-win_arm64.whl", hash = "sha256:656f00e4d903199a1d58df06b711cf3ca632b874b4207b7dbec86185b5c8c7d4", size = 2647362, upload-time = "2026-03-06T13:48:33.411Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/3c/7fcd9b4c9eed82e91fb15568992561019ae7a829d1f696b2c844355d95dd/h5py-3.16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9c9d307c0ef862d1cd5714f72ecfafe0a5d7529c44845afa8de9f46e5ba8bd65", size = 3678608, upload-time = "2026-03-06T13:48:35.183Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/b7/9366ed44ced9b7ef357ab48c94205280276db9d7f064aa3012a97227e966/h5py-3.16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8c1eff849cdd53cbc73c214c30ebdb6f1bb8b64790b4b4fc36acdb5e43570210", size = 3054773, upload-time = "2026-03-06T13:48:37.139Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/a5/4964bc0e91e86340c2bbda83420225b2f770dcf1eb8a39464871ad769436/h5py-3.16.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:e2c04d129f180019e216ee5f9c40b78a418634091c8782e1f723a6ca3658b965", size = 5198886, upload-time = "2026-03-06T13:48:38.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/16/d905e7f53e661ce2c24686c38048d8e2b750ffc4350009d41c4e6c6c9826/h5py-3.16.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:e4360f15875a532bc7b98196c7592ed4fc92672a57c0a621355961cafb17a6dd", size = 5404883, upload-time = "2026-03-06T13:48:41.324Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/f2/58f34cb74af46d39f4cd18ea20909a8514960c5a3e5b92fd06a28161e0a8/h5py-3.16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3fae9197390c325e62e0a1aa977f2f62d994aa87aab182abbea85479b791197c", size = 5192039, upload-time = "2026-03-06T13:48:43.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/ca/934a39c24ce2e2db017268c08da0537c20fa0be7e1549be3e977313fc8f5/h5py-3.16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:43259303989ac8adacc9986695b31e35dba6fd1e297ff9c6a04b7da5542139cc", size = 5421526, upload-time = "2026-03-06T13:48:44.838Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3e/14/615a450205e1b56d16c6783f5ccd116cde05550faad70ae077c955654a75/h5py-3.16.0-cp314-cp314-win_amd64.whl", hash = "sha256:fa48993a0b799737ba7fd21e2350fa0a60701e58180fae9f2de834bc39a147ab", size = 3183263, upload-time = "2026-03-06T13:48:47.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/48/a6faef5ed632cae0c65ac6b214a6614a0b510c3183532c521bdb0055e117/h5py-3.16.0-cp314-cp314-win_arm64.whl", hash = "sha256:1897a771a7f40d05c262fc8f37376ec37873218544b70216872876c627640f63", size = 2663450, upload-time = "2026-03-06T13:48:48.707Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/32/0c8bb8aedb62c772cf7c1d427c7d1951477e8c2835f872bc0a13d1f85f86/h5py-3.16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:15922e485844f77c0b9d275396d435db3baa58292a9c2176a386e072e0cf2491", size = 3760693, upload-time = "2026-03-06T13:48:50.453Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/1f/fcc5977d32d6387c5c9a694afee716a5e20658ac08b3ff24fdec79fb05f2/h5py-3.16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:df02dd29bd247f98674634dfe41f89fd7c16ba3d7de8695ec958f58404a4e618", size = 3181305, upload-time = "2026-03-06T13:48:52.221Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/a1/af87f64b9f986889884243643621ebbd4ac72472ba8ec8cec891ac8e2ca1/h5py-3.16.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:0f456f556e4e2cebeebd9d66adf8dc321770a42593494a0b6f0af54a7567b242", size = 5074061, upload-time = "2026-03-06T13:48:54.089Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/d0/146f5eaff3dc246a9c7f6e5e4f42bd45cc613bce16693bcd4d1f7c958bf5/h5py-3.16.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:3e6cb3387c756de6a9492d601553dffea3fe11b5f22b443aac708c69f3f55e16", size = 5279216, upload-time = "2026-03-06T13:48:56.75Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/9d/12a13424f1e604fc7df9497b73c0356fb78c2fb206abd7465ce47226e8fd/h5py-3.16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8389e13a1fd745ad2856873e8187fd10268b2d9677877bb667b41aebd771d8b7", size = 5070068, upload-time = "2026-03-06T13:48:59.169Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/8c/bbe98f813722b4873818a8db3e15aa3e625b59278566905ac439725e8070/h5py-3.16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:346df559a0f7dcb31cf8e44805319e2ab24b8957c45e7708ce503b2ec79ba725", size = 5300253, upload-time = "2026-03-06T13:49:02.033Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/9e/87e6705b4d6890e7cecdf876e2a7d3e40654a2ae37482d79a6f1b87f7b92/h5py-3.16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:4c6ab014ab704b4feaa719ae783b86522ed0bf1f82184704ed3c9e4e3228796e", size = 3381671, upload-time = "2026-03-06T13:49:04.351Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/91/9fad90cfc5f9b2489c7c26ad897157bce82f0e9534a986a221b99760b23b/h5py-3.16.0-cp314-cp314t-win_arm64.whl", hash = "sha256:faca8fb4e4319c09d83337adc80b2ca7d5c5a343c2d6f1b6388f32cfecca13c1", size = 2740706, upload-time = "2026-03-06T13:49:06.347Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huey"
|
||||
version = "2.6.0"
|
||||
@@ -947,15 +1003,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.2.0"
|
||||
@@ -1618,15 +1665,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettytable"
|
||||
version = "3.17.0"
|
||||
@@ -1718,17 +1756,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "6.33.6"
|
||||
version = "6.31.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/70/e908e9c5e52ef7c3a6c7902c9dfbb34c7e29c25d2f81ade3856445fd5c94/protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135", size = 444531, upload-time = "2026-03-18T19:05:00.988Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/f3/b9655a711b32c19720253f6f06326faf90580834e2e83f840472d752bc8b/protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a", size = 441797, upload-time = "2025-05-28T19:25:54.947Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/9f/2f509339e89cfa6f6a4c4ff50438db9ca488dec341f7e454adad60150b00/protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3", size = 425739, upload-time = "2026-03-18T19:04:48.373Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/5d/683efcd4798e0030c1bab27374fd13a89f7c2515fb1f3123efdfaa5eab57/protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326", size = 437089, upload-time = "2026-03-18T19:04:50.381Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/01/a3c3ed5cd186f39e7880f8303cc51385a198a81469d53d0fdecf1f64d929/protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a", size = 427737, upload-time = "2026-03-18T19:04:51.866Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/90/b3c01fdec7d2f627b3a6884243ba328c1217ed2d978def5c12dc50d328a3/protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2", size = 324610, upload-time = "2026-03-18T19:04:53.096Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/ca/25afc144934014700c52e05103c2421997482d561f3101ff352e1292fb81/protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3", size = 339381, upload-time = "2026-03-18T19:04:54.616Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/92/d1e32e3e0d894fe00b15ce28ad4944ab692713f2e7f0a99787405e43533a/protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593", size = 323436, upload-time = "2026-03-18T19:04:55.768Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/72/02445137af02769918a93807b2b7890047c32bfb9f90371cbc12688819eb/protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901", size = 170656, upload-time = "2026-03-18T19:04:59.826Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/6f/6ab8e4bf962fd5570d3deaa2d5c38f0a363f57b4501047b5ebeb83ab1125/protobuf-6.31.1-cp310-abi3-win32.whl", hash = "sha256:7fa17d5a29c2e04b7d90e5e32388b8bfd0e7107cd8e616feef7ed3fa6bdab5c9", size = 423603, upload-time = "2025-05-28T19:25:41.198Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/44/3a/b15c4347dd4bf3a1b0ee882f384623e2063bb5cf9fa9d57990a4f7df2fb6/protobuf-6.31.1-cp310-abi3-win_amd64.whl", hash = "sha256:426f59d2964864a1a366254fa703b8632dcec0790d8862d30034d8245e1cd447", size = 435283, upload-time = "2025-05-28T19:25:44.275Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/c9/b9689a2a250264a84e66c46d8862ba788ee7a641cdca39bccf64f59284b7/protobuf-6.31.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:6f1227473dc43d44ed644425268eb7c2e488ae245d51c6866d19fe158e207402", size = 425604, upload-time = "2025-05-28T19:25:45.702Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/a1/7a5a94032c83375e4fe7e7f56e3976ea6ac90c5e85fac8576409e25c39c3/protobuf-6.31.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:a40fc12b84c154884d7d4c4ebd675d5b3b5283e155f324049ae396b95ddebc39", size = 322115, upload-time = "2025-05-28T19:25:47.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/b1/b59d405d64d31999244643d88c45c8241c58f17cc887e73bcb90602327f8/protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4ee898bf66f7a8b0bd21bce523814e6fbd8c6add948045ce958b73af7e8878c6", size = 321070, upload-time = "2025-05-28T19:25:50.036Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/af/ab3c51ab7507a7325e98ffe691d9495ee3d3aa5f589afad65ec920d39821/protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e", size = 168724, upload-time = "2025-05-28T19:25:53.926Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1908,22 +1945,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "9.0.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -2003,6 +2024,29 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qai-hub"
|
||||
version = "0.50.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "backoff" },
|
||||
{ name = "deprecation" },
|
||||
{ name = "h5py" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "prettytable" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "requests" },
|
||||
{ name = "requests-toolbelt" },
|
||||
{ name = "s3transfer" },
|
||||
{ name = "semver" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/d8/d25fea29362a762b0d739ca8bfcfbda8b7af7f028813fa4c76a91edabfb1/qai_hub-0.50.0-py3-none-any.whl", hash = "sha256:a0b1e93fc3e358c02151042676779a793fea028d78b09854a3b4c6e0719bc0ce", size = 123503, upload-time = "2026-05-28T23:08:06.19Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qc-cli"
|
||||
version = "0.1.0"
|
||||
@@ -2011,22 +2055,19 @@ dependencies = [
|
||||
{ name = "aws-cdk-lib" },
|
||||
{ name = "boto3" },
|
||||
{ name = "constructs" },
|
||||
{ name = "mlflow" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "typer" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
mlflow = [
|
||||
{ name = "mlflow" },
|
||||
{ name = "qai-hub" },
|
||||
{ name = "sagemaker-mlflow" },
|
||||
{ name = "typer" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "boto3-stubs", extra = ["iam", "s3", "sagemaker"] },
|
||||
{ name = "pyright" },
|
||||
{ name = "pytest" },
|
||||
{ name = "ruff" },
|
||||
{ name = "types-pyyaml" },
|
||||
]
|
||||
@@ -2036,19 +2077,19 @@ requires-dist = [
|
||||
{ name = "aws-cdk-lib", specifier = ">=2.180.0" },
|
||||
{ name = "boto3", specifier = ">=1.34,<1.42" },
|
||||
{ name = "constructs", specifier = ">=10.0.0" },
|
||||
{ name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.0" },
|
||||
{ name = "mlflow", specifier = ">=3.0" },
|
||||
{ name = "numpy", specifier = ">=1.26" },
|
||||
{ name = "pydantic", specifier = ">=2.13.3" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.3" },
|
||||
{ name = "sagemaker-mlflow", marker = "extra == 'mlflow'", specifier = ">=0.1.0" },
|
||||
{ name = "qai-hub", specifier = ">=0.49.0" },
|
||||
{ name = "sagemaker-mlflow", specifier = ">=0.4.0" },
|
||||
{ name = "typer", specifier = "==0.25.0" },
|
||||
]
|
||||
provides-extras = ["mlflow"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "boto3-stubs", extras = ["iam", "s3", "sagemaker"] },
|
||||
{ name = "pyright", specifier = ">=1.1.409" },
|
||||
{ name = "pytest", specifier = ">=8.0" },
|
||||
{ name = "ruff", specifier = ">=0.4" },
|
||||
{ name = "types-pyyaml" },
|
||||
]
|
||||
@@ -2068,6 +2109,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/f4/c67b0b3f1b9245e8d266f0f112c500d50e5b4e83cb6f3b71b6528104182a/requests-2.34.2-py3-none-any.whl", hash = "sha256:2a0d60c172f83ac6ab31e4554906c0f3b3588d37b5cb939b1c061f4907e278e0", size = 73075, upload-time = "2026-05-14T19:25:26.443Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests-toolbelt"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", size = 206888, upload-time = "2023-05-01T04:11:33.229Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "15.0.0"
|
||||
@@ -2220,6 +2273,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "3.0.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/d1/d3159231aec234a59dd7d601e9dd9fe96f3afff15efd33c1070019b26132/semver-3.0.4.tar.gz", hash = "sha256:afc7d8c584a5ed0a11033af086e8af226a9c0b206f313e0301f8dd7b6b589602", size = 269730, upload-time = "2025-01-24T13:19:27.617Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912, upload-time = "2025-01-24T13:19:24.949Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shellingham"
|
||||
version = "1.5.4"
|
||||
@@ -2327,6 +2389,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.67.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/a9/6ba95a270c6f1fbcd8dac228323f2777d886cb206987444e4bce66338dd4/tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb", size = 169598, upload-time = "2026-02-03T17:35:53.048Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf", size = 78374, upload-time = "2026-02-03T17:35:50.982Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeguard"
|
||||
version = "2.13.3"
|
||||
|
||||
Reference in New Issue
Block a user