Compare commits
13 Commits
main
...
b411be7904
| Author | SHA1 | Date | |
|---|---|---|---|
| b411be7904 | |||
| 090be14a6a | |||
| d3ebd2cc5f | |||
| 57a8a0a9c4 | |||
| a43c792cfd | |||
| cf6a561e2f | |||
| 416e51901d | |||
| 556797cf13 | |||
| 19fef8638b | |||
| 58681cef82 | |||
| e1c8d6574f | |||
| 35d25d8967 | |||
| b907a74525 |
65
README.md
65
README.md
@@ -67,8 +67,7 @@ sagemaker:
|
|||||||
hyperparameters: {}
|
hyperparameters: {}
|
||||||
|
|
||||||
aihub:
|
aihub:
|
||||||
device:
|
device: Samsung Galaxy S25 (Family)
|
||||||
name: Samsung Galaxy S25 (Family)
|
|
||||||
target_runtime: tflite
|
target_runtime: tflite
|
||||||
input_specs: {} # Required before running qc-cli ai-hub commands
|
input_specs: {} # Required before running qc-cli ai-hub commands
|
||||||
job_name: null # Optional prefix for AI Hub Workbench jobs
|
job_name: null # Optional prefix for AI Hub Workbench jobs
|
||||||
@@ -105,15 +104,15 @@ mlflow:
|
|||||||
tracking_server_name: your-tracking-server-name
|
tracking_server_name: your-tracking-server-name
|
||||||
```
|
```
|
||||||
|
|
||||||
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. Training metrics can be upload with `train start --upload-metrics` or `mlflow upload-metrics`.
|
When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as experiment model versions using the `experiment-latest` MLflow alias. An experiment version is an immutable trained-source artifact; it records that training produced a model, not that the model is better than earlier versions or ready for release.
|
||||||
|
|
||||||
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
To open the managed SageMaker MLflow UI, request a fresh presigned URL:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
qc-cli mlflow open --config config.yaml
|
qc-cli infra mlflow-url --config config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This opens a browser to a fresh presigned URL. It works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
|
This works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
@@ -132,6 +131,7 @@ qc-cli infra setup Deploy the CDK stack
|
|||||||
qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
|
qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
|
||||||
qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
|
qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
|
||||||
qc-cli infra status Show CDK stack/resource status
|
qc-cli infra status Show CDK stack/resource status
|
||||||
|
qc-cli infra mlflow-url Print a presigned MLflow UI URL
|
||||||
qc-cli infra destroy Destroy stack, retaining S3 data
|
qc-cli infra destroy Destroy stack, retaining S3 data
|
||||||
qc-cli infra destroy --yes Destroy stack without confirmation
|
qc-cli infra destroy --yes Destroy stack without confirmation
|
||||||
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
||||||
@@ -143,15 +143,6 @@ qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
|
|||||||
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
|
qc-cli infra setup --cloudformation-execution-policy arn:aws:iam::aws:policy/PowerUserAccess
|
||||||
```
|
```
|
||||||
|
|
||||||
### `mlflow`
|
|
||||||
|
|
||||||
```
|
|
||||||
qc-cli mlflow open Open a presigned MLflow UI URL
|
|
||||||
qc-cli mlflow upload-metrics [job-name] Upload completed training metrics
|
|
||||||
```
|
|
||||||
|
|
||||||
`mlflow upload-metrics` defaults to the last submitted training job. It creates or recovers the job's MLflow run, imports `training_metrics.json` from the SageMaker model artifact, and records successful upload in `.qc-cli.json`. Use `--force` to upload the metrics again.
|
|
||||||
|
|
||||||
### `upload`
|
### `upload`
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -166,7 +157,6 @@ Uploads use `s3.bucket` and `s3.data_prefix` from `config.yaml`. File uploads de
|
|||||||
|
|
||||||
```
|
```
|
||||||
qc-cli train start Submit a SageMaker training job
|
qc-cli train start Submit a SageMaker training job
|
||||||
qc-cli train start --upload-metrics Submit, wait, and upload metrics
|
|
||||||
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
qc-cli train status [job-name] Show job status; defaults to the last submitted job
|
||||||
qc-cli train list List recent training jobs
|
qc-cli train list List recent training jobs
|
||||||
qc-cli train list --limit 3 Show a custom number of recent jobs
|
qc-cli train list --limit 3 Show a custom number of recent jobs
|
||||||
@@ -174,8 +164,6 @@ qc-cli train list --limit 3 Show a custom number of recent jobs
|
|||||||
|
|
||||||
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
`train start` uses `s3://<bucket>/<data_prefix>/` as the training channel and writes outputs under `s3://<bucket>/<model_prefix>/`. If `sagemaker.training.source_dir` is set, the CLI packages that directory, uploads it beside the job output prefix, and passes `sagemaker_program`/`sagemaker_submit_directory` to the SageMaker container.
|
||||||
|
|
||||||
`train start --upload-metrics` checks SageMaker every 30 seconds by default, then uploads metrics after completion. Use `--poll-interval <seconds>` to choose another positive interval. Stopping the local command does not stop the SageMaker job.
|
|
||||||
|
|
||||||
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
The expected output artifact is SageMaker’s `model.tar.gz`, normally containing the trained model file your container writes to `/opt/ml/model`.
|
||||||
|
|
||||||
### `ai-hub`
|
### `ai-hub`
|
||||||
@@ -183,29 +171,16 @@ The expected output artifact is SageMaker’s `model.tar.gz`, normally containin
|
|||||||
```
|
```
|
||||||
qc-cli ai-hub upload <calibration.npz|calibration-dir> <inputs.npz|inputs.npy>
|
qc-cli ai-hub upload <calibration.npz|calibration-dir> <inputs.npz|inputs.npy>
|
||||||
qc-cli ai-hub upload <calibration> <inputs> --from-step validate
|
qc-cli ai-hub upload <calibration> <inputs> --from-step validate
|
||||||
qc-cli ai-hub optimize [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
qc-cli ai-hub quantize <calibration.npz|calibration-dir> [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||||
qc-cli ai-hub quantize <calibration.npz|calibration-dir> [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
|
||||||
qc-cli ai-hub compile [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
qc-cli ai-hub compile [--model-id ID] [--onnx-path PATH] [--model-s3-uri URI] [--from-job NAME]
|
||||||
qc-cli ai-hub validate <inputs.npz|inputs.npy> [--model-id ID] [--input-name NAME]
|
qc-cli ai-hub validate <inputs.npz|inputs.npy> [--model-id ID] [--input-name NAME]
|
||||||
qc-cli ai-hub profile [--model-id ID]
|
qc-cli ai-hub profile [--model-id ID]
|
||||||
qc-cli ai-hub download [--model-id ID] [--output PATH]
|
qc-cli ai-hub download [--model-id ID] [--output PATH]
|
||||||
```
|
```
|
||||||
|
|
||||||
`ai-hub upload` optimizes to ONNX, quantizes, validates, and profiles. When `aihub.target_runtime` is not `onnx`, it also compiles the quantized model to that deployment runtime. The initial ONNX optimization gives external models Workbench provenance and applies compiler optimization passes before quantization.
|
`ai-hub upload` runs the four Workbench upload steps in order: quantize, compile, validate, and profile. Use `--from-step compile`, `--from-step validate`, or `--from-step profile` to resume from saved local state after a completed earlier step.
|
||||||
|
|
||||||
Resume behavior:
|
`ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options (`--onnx-path`, `--model-s3-uri`, `--from-job`), last quantized model from state, then the last training job from local state. `ai-hub download` is separate because downloading the optimized artifact is outside the four-step Workbench upload loop.
|
||||||
|
|
||||||
```text
|
|
||||||
--from-step optimize Run optimize, quantize, optional final compile, validate, and profile.
|
|
||||||
--from-step quantize Quantize the last optimized ONNX, then optionally compile, validate, and profile.
|
|
||||||
--from-step compile Skip optimize and quantize; finalize the last quantized model for the target runtime.
|
|
||||||
--from-step validate Skip optimize, quantize, and compile; validate the last compiled model.
|
|
||||||
--from-step profile Skip optimize, quantize, compile, and validate; profile the last compiled model.
|
|
||||||
```
|
|
||||||
|
|
||||||
When a step runs in the current command, `upload` passes its returned model ID directly to the next step. When a step is skipped, the next step resolves the needed model ID from `.qc-cli.json`. This avoids re-running earlier AI Hub jobs when you only need to continue from a later step.
|
|
||||||
|
|
||||||
`ai-hub optimize` compiles an external model with `--target_runtime onnx`. `ai-hub quantize` uses an explicit `--model-id`, the last optimized ONNX model, or an explicit/local model source in that order. `ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options, last quantized model, then the last training job. For `target_runtime: onnx`, upload treats the quantized ONNX as the final model and skips a redundant second compile. `ai-hub download` remains separate because downloading is outside the Workbench processing loop.
|
|
||||||
|
|
||||||
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
|
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.
|
||||||
|
|
||||||
@@ -216,29 +191,13 @@ The CLI uses neutral experiment naming for trained artifacts and reserves releas
|
|||||||
Current behavior:
|
Current behavior:
|
||||||
|
|
||||||
1. `qc-cli train start` submits a SageMaker training job.
|
1. `qc-cli train start` submits a SageMaker training job.
|
||||||
2. `qc-cli train status` reads and displays SageMaker status only; it does not contact MLflow.
|
2. `qc-cli train status` finalizes the MLflow run after the job reaches a terminal state.
|
||||||
3. `qc-cli train start --upload-metrics` polls every 30 seconds by default, then uploads per-epoch metrics after completion.
|
3. If the job completed and `mlflow.register_trained_models` is enabled, the SageMaker `model.tar.gz` is registered as a new MLflow model version with:
|
||||||
4. `qc-cli mlflow upload-metrics [job-name]` uploads or retries metrics for an existing completed job.
|
|
||||||
5. The metrics upload workflow finalizes the MLflow run and, when `mlflow.register_trained_models` is enabled, registers the SageMaker `model.tar.gz` as a new MLflow model version with:
|
|
||||||
- `qc_cli.stage=experiment`
|
- `qc_cli.stage=experiment`
|
||||||
- `qc_cli.artifact_kind=trained_source`
|
- `qc_cli.artifact_kind=trained_source`
|
||||||
- `qc_cli.source=sagemaker`
|
- `qc_cli.source=sagemaker`
|
||||||
6. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
4. The MLflow alias `experiment-latest` points at the most recently registered experiment version.
|
||||||
7. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
5. AI Hub upload commands create deployable derived artifacts from a trained-source experiment or local ONNX model.
|
||||||
|
|
||||||
Training scripts can include a `training_metrics.json` file in the SageMaker model directory. When present, the explicit metrics upload command logs its ordered metrics to the associated MLflow run using each epoch as the MLflow step and stores the JSON as a run artifact:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"schema_version": 1,
|
|
||||||
"steps": [
|
|
||||||
{"step": 0, "metrics": {"val.precision": 0.72, "val.recall": 0.68}}
|
|
||||||
],
|
|
||||||
"summary": {"summary.best_epoch": 0}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Metric names must be non-empty strings, values must be finite numbers, and steps must be non-negative, unique, and strictly increasing. If the file is missing, the command uploads the final metrics reported by SageMaker and continues model registration without per-epoch history. A malformed metrics artifact still fails the upload command without affecting the trained model or model registration.
|
|
||||||
|
|
||||||
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
Future release aliases such as `v1` or `production` can point at a selected deployable artifact.
|
||||||
|
|
||||||
|
|||||||
117
examples/ai-hub/README.md
Normal file
117
examples/ai-hub/README.md
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# Qualcomm AI Hub Example
|
||||||
|
|
||||||
|
This example takes the ONNX model produced by the SageMaker training example and runs the Qualcomm AI Hub upload workflow:
|
||||||
|
|
||||||
|
1. Quantize
|
||||||
|
2. Compile
|
||||||
|
3. Validate
|
||||||
|
4. Profile
|
||||||
|
5. Download the compiled artifact
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Run the training example first and wait for it to complete:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh --config config.yaml --wait
|
||||||
|
```
|
||||||
|
|
||||||
|
If the dataset is already uploaded to S3, use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh --config config.yaml --skip-upload --wait
|
||||||
|
```
|
||||||
|
|
||||||
|
The training artifact must contain a static-shape `model.onnx`. The training example exports an input named `input` with shape `1x3x160x160`.
|
||||||
|
|
||||||
|
Your `config.yaml` must include AI Hub settings:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
aihub:
|
||||||
|
device: Samsung Galaxy S25 (Family)
|
||||||
|
target_runtime: tflite
|
||||||
|
input_specs:
|
||||||
|
input: [[1, 3, 160, 160], float32]
|
||||||
|
output_dir: build/qai-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
You also need local Qualcomm AI Hub SDK authentication configured.
|
||||||
|
|
||||||
|
## Prepare Inputs
|
||||||
|
|
||||||
|
AI Hub does not consume the raw JPG training images directly. It needs NumPy tensors that match the ONNX model input shape and preprocessing.
|
||||||
|
|
||||||
|
Generate calibration and validation inputs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python examples/ai-hub/prepare_inputs.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This writes:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/training/data/aihub_calibration/*.npy
|
||||||
|
examples/training/data/inputs.npz
|
||||||
|
```
|
||||||
|
|
||||||
|
The script applies the same image preprocessing used by the training example:
|
||||||
|
|
||||||
|
- resize to `160x160`
|
||||||
|
- convert to channel-first `1x3x160x160`
|
||||||
|
- normalize with ImageNet mean and standard deviation
|
||||||
|
|
||||||
|
Useful options:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python examples/ai-hub/prepare_inputs.py \
|
||||||
|
--dataset-dir examples/training/data/flower_photos_sagemaker \
|
||||||
|
--calibration-dir examples/training/data/aihub_calibration \
|
||||||
|
--input-file examples/training/data/inputs.npz \
|
||||||
|
--samples 16
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run AI Hub
|
||||||
|
|
||||||
|
After training completes and inputs are prepared:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/ai-hub/run_ai_hub.sh --config config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, the script uses the last SageMaker training job recorded in `.qc-cli.json`. It downloads that job's `model.tar.gz`, extracts `model.onnx`, runs the AI Hub workflow, and downloads the compiled artifact.
|
||||||
|
|
||||||
|
To use a specific training job:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/ai-hub/run_ai_hub.sh \
|
||||||
|
--config config.yaml \
|
||||||
|
--from-job qc-cli-YYYYMMDD-HHMMSS
|
||||||
|
```
|
||||||
|
|
||||||
|
To resume from a later Workbench step:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/ai-hub/run_ai_hub.sh \
|
||||||
|
--config config.yaml \
|
||||||
|
--from-step validate
|
||||||
|
```
|
||||||
|
|
||||||
|
To skip downloading the compiled artifact:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/ai-hub/run_ai_hub.sh \
|
||||||
|
--config config.yaml \
|
||||||
|
--skip-download
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
If AI Hub reports dynamic input shapes, rerun training with the current training source. AI Hub quantization requires the exported ONNX model to use static input shapes.
|
||||||
|
|
||||||
|
If `run_ai_hub.sh` reports missing calibration or input files, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python examples/ai-hub/prepare_inputs.py
|
||||||
|
```
|
||||||
|
|
||||||
|
If validation fails with a missing input name, make sure `config.yaml` and the generated `.npz` both use `input` as the input name.
|
||||||
75
examples/ai-hub/prepare_inputs.py
Executable file
75
examples/ai-hub/prepare_inputs.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Prepare Qualcomm AI Hub calibration and validation inputs for the training example."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/training/data/flower_photos_sagemaker"),
|
||||||
|
help="ImageFolder-style dataset used for training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--calibration-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/training/data/aihub_calibration"),
|
||||||
|
help="Directory where .npy calibration samples will be written.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-file",
|
||||||
|
type=Path,
|
||||||
|
default=Path("examples/training/data/inputs.npz"),
|
||||||
|
help="Validation .npz input file for qc-cli ai-hub validate.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--input-name", default="input", help="ONNX input name.")
|
||||||
|
parser.add_argument("--image-size", type=int, default=160, help="Square image size used by training.")
|
||||||
|
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
|
||||||
|
image = Image.open(path).convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
|
||||||
|
array = np.asarray(image, dtype=np.float32) / 255.0
|
||||||
|
array = np.transpose(array, (2, 0, 1))
|
||||||
|
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
|
||||||
|
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
|
||||||
|
return ((array - mean) / std)[None, ...].astype("float32")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
images = sorted(p for p in args.dataset_dir.rglob("*") if p.suffix.lower() in IMAGE_EXTENSIONS)
|
||||||
|
if not images:
|
||||||
|
raise SystemExit(f"No images found under {args.dataset_dir}")
|
||||||
|
if args.samples < 1:
|
||||||
|
raise SystemExit("--samples must be at least 1")
|
||||||
|
|
||||||
|
args.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.input_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
sample_count = min(args.samples, len(images))
|
||||||
|
prepared = []
|
||||||
|
for index, image_path in enumerate(images[:sample_count]):
|
||||||
|
sample = preprocess_image(image_path, args.image_size)
|
||||||
|
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
|
||||||
|
prepared.append(sample)
|
||||||
|
|
||||||
|
np.savez(args.input_file, **{args.input_name: prepared[0]})
|
||||||
|
print(f"Wrote {sample_count} calibration samples to {args.calibration_dir}")
|
||||||
|
print(f"Wrote validation input to {args.input_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
156
examples/ai-hub/run_ai_hub.sh
Executable file
156
examples/ai-hub/run_ai_hub.sh
Executable file
@@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
CONFIG_PATH="config.yaml"
|
||||||
|
CALIBRATION_PATH="examples/training/data/aihub_calibration"
|
||||||
|
INPUT_FILE="examples/training/data/inputs.npz"
|
||||||
|
FROM_STEP="quantize"
|
||||||
|
FROM_JOB=""
|
||||||
|
MODEL_S3_URI=""
|
||||||
|
ONNX_PATH=""
|
||||||
|
INPUT_NAME=""
|
||||||
|
DOWNLOAD=true
|
||||||
|
OUTPUT_PATH=""
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat <<EOF
|
||||||
|
Usage: $0 [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--config PATH Path to qc-cli config file. Default: config.yaml
|
||||||
|
--calibration PATH Calibration .npz file or directory of .npy samples.
|
||||||
|
Default: ${CALIBRATION_PATH}
|
||||||
|
--input-file PATH Validation .npz or .npy inputs. Default: ${INPUT_FILE}
|
||||||
|
--from-step STEP Resume upload from: quantize, compile, validate, profile.
|
||||||
|
Default: ${FROM_STEP}
|
||||||
|
--from-job NAME SageMaker training job whose model artifact should upload.
|
||||||
|
Defaults to the last training job in local qc-cli state.
|
||||||
|
--model-s3-uri URI S3 URI of model.tar.gz to upload.
|
||||||
|
--onnx-path PATH Local ONNX path or ONNX path inside extracted artifact.
|
||||||
|
--input-name NAME Input name for .npy validation files.
|
||||||
|
--skip-download Do not download the compiled AI Hub artifact after upload.
|
||||||
|
--output PATH Destination file for ai-hub download.
|
||||||
|
-h, --help Show this help.
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--config)
|
||||||
|
CONFIG_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--calibration)
|
||||||
|
CALIBRATION_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--input-file)
|
||||||
|
INPUT_FILE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--from-step)
|
||||||
|
FROM_STEP="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--from-job)
|
||||||
|
FROM_JOB="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--model-s3-uri)
|
||||||
|
MODEL_S3_URI="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--onnx-path)
|
||||||
|
ONNX_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--input-name)
|
||||||
|
INPUT_NAME="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--skip-download)
|
||||||
|
DOWNLOAD=false
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
--output)
|
||||||
|
OUTPUT_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1" >&2
|
||||||
|
usage >&2
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ! -f "${CONFIG_PATH}" ]]; then
|
||||||
|
echo "Config not found: ${CONFIG_PATH}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
case "${FROM_STEP}" in
|
||||||
|
quantize|compile|validate|profile)
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "--from-step must be one of: quantize, compile, validate, profile" >&2
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if [[ ! -e "${CALIBRATION_PATH}" ]]; then
|
||||||
|
echo "Calibration path not found: ${CALIBRATION_PATH}" >&2
|
||||||
|
echo "Pass --calibration with a .npz file or directory of .npy samples." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ! -f "${INPUT_FILE}" ]]; then
|
||||||
|
echo "Input file not found: ${INPUT_FILE}" >&2
|
||||||
|
echo "Pass --input-file with a validation .npz or .npy file." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
run() {
|
||||||
|
echo "+ $*"
|
||||||
|
"$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
UPLOAD_ARGS=(
|
||||||
|
"${CALIBRATION_PATH}"
|
||||||
|
"${INPUT_FILE}"
|
||||||
|
--from-step "${FROM_STEP}"
|
||||||
|
--config "${CONFIG_PATH}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ -n "${FROM_JOB}" ]]; then
|
||||||
|
UPLOAD_ARGS+=(--from-job "${FROM_JOB}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${MODEL_S3_URI}" ]]; then
|
||||||
|
UPLOAD_ARGS+=(--model-s3-uri "${MODEL_S3_URI}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${ONNX_PATH}" ]]; then
|
||||||
|
UPLOAD_ARGS+=(--onnx-path "${ONNX_PATH}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${INPUT_NAME}" ]]; then
|
||||||
|
UPLOAD_ARGS+=(--input-name "${INPUT_NAME}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
run uv run qc-cli ai-hub upload "${UPLOAD_ARGS[@]}"
|
||||||
|
|
||||||
|
if [[ "${DOWNLOAD}" == false ]]; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
DOWNLOAD_ARGS=(--config "${CONFIG_PATH}")
|
||||||
|
if [[ -n "${OUTPUT_PATH}" ]]; then
|
||||||
|
DOWNLOAD_ARGS+=(--output "${OUTPUT_PATH}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
run uv run qc-cli ai-hub download "${DOWNLOAD_ARGS[@]}"
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
# YOLO26 Electric Meter Detection Example
|
|
||||||
|
|
||||||
This example trains a YOLO26 object detection model on the Roboflow Universe electric meter dataset using the existing `qc-cli` SageMaker training flow.
|
|
||||||
|
|
||||||
The workflow is intentionally command driven. Run each step yourself so you can inspect the dataset, update `config.yaml`, and decide when to submit the SageMaker job.
|
|
||||||
|
|
||||||
Dataset:
|
|
||||||
|
|
||||||
```text
|
|
||||||
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
|
||||||
```
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
- Install or sync the project dependencies: `uv sync`
|
|
||||||
- The virtual environment is activated.
|
|
||||||
- AWS credentials configured for the profile in `config.yaml`
|
|
||||||
- Infrastructure already deployed with `qc-cli infra setup`
|
|
||||||
|
|
||||||
## 1. Download The Dataset
|
|
||||||
|
|
||||||
Register or sign in to Roboflow, then open the dataset page:
|
|
||||||
|
|
||||||
```text
|
|
||||||
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
|
|
||||||
```
|
|
||||||
|
|
||||||
Download the dataset in YOLOv26 format from the Roboflow UI, then extract the downloaded archive into:
|
|
||||||
|
|
||||||
```text
|
|
||||||
examples/meter-detection/data/electric-meter-detection
|
|
||||||
```
|
|
||||||
|
|
||||||
The `data.yaml` file should be directly under that folder:
|
|
||||||
|
|
||||||
```text
|
|
||||||
examples/meter-detection/data/electric-meter-detection/data.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
Do not move `data.yaml` into the `train/` split folder.
|
|
||||||
|
|
||||||
After extracting, confirm the dataset has a YOLO data file and image splits:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
find examples/meter-detection/data/electric-meter-detection -maxdepth 2 -type d | sort
|
|
||||||
find examples/meter-detection/data/electric-meter-detection -name data.yaml -print
|
|
||||||
```
|
|
||||||
|
|
||||||
Open `examples/meter-detection/data/electric-meter-detection/data.yaml` and make sure the split paths are relative to that folder:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
path: .
|
|
||||||
train: train/images
|
|
||||||
val: valid/images
|
|
||||||
test: test/images
|
|
||||||
```
|
|
||||||
|
|
||||||
If your downloaded dataset does not include a `test/` folder, remove the `test:` line.
|
|
||||||
|
|
||||||
The expected layout is similar to:
|
|
||||||
|
|
||||||
```text
|
|
||||||
examples/meter-detection/data/electric-meter-detection/
|
|
||||||
data.yaml
|
|
||||||
train/
|
|
||||||
valid/
|
|
||||||
test/
|
|
||||||
```
|
|
||||||
|
|
||||||
## 2. Configure SageMaker Training
|
|
||||||
|
|
||||||
Update `config.yaml` so the training section points at this example's source directory:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
sagemaker:
|
|
||||||
training:
|
|
||||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
|
||||||
instance_type: ml.g4dn.xlarge
|
|
||||||
instance_count: 1
|
|
||||||
source_dir: examples/meter-detection/source
|
|
||||||
entry_point: train.py
|
|
||||||
hyperparameters:
|
|
||||||
model: yolo26n.pt
|
|
||||||
epochs: 25
|
|
||||||
imgsz: 640
|
|
||||||
batch: 16
|
|
||||||
workers: 2
|
|
||||||
```
|
|
||||||
|
|
||||||
Use `yolo26n.pt` for a lightweight first YOLO26 run. If those weights are unavailable in the installed Ultralytics package, use `yolo11n.pt` as the established fallback:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model: yolo11n.pt
|
|
||||||
```
|
|
||||||
|
|
||||||
The `source/requirements.txt` file is installed by the SageMaker PyTorch container before running `train.py`.
|
|
||||||
|
|
||||||
For a CPU smoke test, use a CPU instance and reduce the workload:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
sagemaker:
|
|
||||||
training:
|
|
||||||
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
|
||||||
instance_type: ml.m4.xlarge
|
|
||||||
instance_count: 1
|
|
||||||
source_dir: examples/meter-detection/source
|
|
||||||
entry_point: train.py
|
|
||||||
hyperparameters:
|
|
||||||
model: yolo26n.pt
|
|
||||||
epochs: 1
|
|
||||||
imgsz: 320
|
|
||||||
batch: 4
|
|
||||||
workers: 2
|
|
||||||
```
|
|
||||||
|
|
||||||
## 3. Check Infrastructure
|
|
||||||
|
|
||||||
Confirm the CLI can see the configured SageMaker role and S3 bucket:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli infra status
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. Upload The Dataset
|
|
||||||
|
|
||||||
Upload the downloaded Roboflow dataset to the `s3.data_prefix` configured in `config.yaml`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli upload examples/meter-detection/data/electric-meter-detection
|
|
||||||
```
|
|
||||||
|
|
||||||
Directory uploads preserve paths relative to the uploaded directory, so SageMaker receives the dataset root with `data.yaml` plus the split directories.
|
|
||||||
|
|
||||||
In SageMaker, this uploaded dataset root is mounted at `/opt/ml/input/data/train`. That `train` path is the SageMaker channel name, not the YOLO `train/` split folder.
|
|
||||||
|
|
||||||
## 5. Start Training
|
|
||||||
|
|
||||||
Submit the SageMaker training job:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli train start
|
|
||||||
```
|
|
||||||
|
|
||||||
The command prints the submitted SageMaker job name. Check progress with:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli train status
|
|
||||||
```
|
|
||||||
|
|
||||||
Or pass the job name explicitly:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
|
|
||||||
```
|
|
||||||
|
|
||||||
To submit the job, wait for completion, and automatically import metrics and register the model, run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli train start --upload-metrics
|
|
||||||
```
|
|
||||||
|
|
||||||
The default polling interval is 30 seconds. It can be changed with `--poll-interval <seconds>`.
|
|
||||||
|
|
||||||
The metrics can be also submitted using:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli mlflow upload-metrics
|
|
||||||
```
|
|
||||||
|
|
||||||
## SageMaker Outputs
|
|
||||||
|
|
||||||
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
|
|
||||||
|
|
||||||
This example writes:
|
|
||||||
|
|
||||||
```text
|
|
||||||
best.pt
|
|
||||||
model.onnx
|
|
||||||
metrics.json
|
|
||||||
training_metrics.json
|
|
||||||
```
|
|
||||||
|
|
||||||
The archive is stored under the configured `s3.model_prefix`.
|
|
||||||
|
|
||||||
The `mlflow upload-metrics` command imports `training_metrics.json`, which provides per-epoch training and validation
|
|
||||||
losses, precision, recall, mAP@0.50, mAP@0.50:0.95, and learning rates. For object detection, mAP and precision/recall
|
|
||||||
are more meaningful than classification accuracy when assessing model quality.
|
|
||||||
|
|
||||||
## 6. Configure Qualcomm AI Hub
|
|
||||||
|
|
||||||
Authenticate with Qualcomm AI Hub:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qai-hub configure --api_token
|
|
||||||
```
|
|
||||||
|
|
||||||
Add AI Hub settings to `config.yaml`. The input name and image size must match the ONNX model exported by this example:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
aihub:
|
|
||||||
device:
|
|
||||||
name: Dragonwing IQ-9075 EVK
|
|
||||||
target_runtime: onnx
|
|
||||||
input_specs:
|
|
||||||
images: [[1, 3, 640, 640], float32]
|
|
||||||
job_name: meter-detection
|
|
||||||
model_name: meter-detection
|
|
||||||
output_dir: build/qai-hub/meter-detection
|
|
||||||
```
|
|
||||||
|
|
||||||
The ONNX graph is the source of truth. The export normally uses the same value as `sagemaker.training.hyperparameters.imgsz`, but changing `config.yaml` after training does not resize an existing model. For example, a model exported with `imgsz: 320` requires `images: [[1, 3, 320, 320], float32]`.
|
|
||||||
|
|
||||||
## 7. Prepare AI Hub Inputs
|
|
||||||
|
|
||||||
Generate calibration samples and a validation input from the downloaded dataset:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run python examples/meter-detection/prepare_aihub_inputs.py --image-size 640
|
|
||||||
```
|
|
||||||
|
|
||||||
This writes:
|
|
||||||
|
|
||||||
```text
|
|
||||||
examples/meter-detection/data/aihub_calibration/*.npy
|
|
||||||
examples/meter-detection/data/inputs.npz
|
|
||||||
```
|
|
||||||
|
|
||||||
The script applies the preprocessing expected by the exported YOLO model: aspect-ratio-preserving letterboxing, RGB channel order, channel-first layout, and pixel values normalized to `[0, 1]`.
|
|
||||||
|
|
||||||
## 8. Upload To Qualcomm AI Hub
|
|
||||||
|
|
||||||
Use the SageMaker job name printed by `qc-cli train start`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli ai-hub upload \
|
|
||||||
examples/meter-detection/data/aihub_calibration \
|
|
||||||
examples/meter-detection/data/inputs.npz \
|
|
||||||
--from-job qc-cli-YYYYMMDD-HHMMSS
|
|
||||||
```
|
|
||||||
|
|
||||||
The command downloads the job's `model.tar.gz`, finds `model.onnx`, and runs the following AI Hub workflow:
|
|
||||||
|
|
||||||
1. Compile the external ONNX to a Workbench-optimized ONNX model.
|
|
||||||
2. Quantize the optimized ONNX model.
|
|
||||||
3. Compile the quantized model when the configured deployment runtime is not `onnx`.
|
|
||||||
4. Validate and profile the final model.
|
|
||||||
|
|
||||||
The training example sanitizes the Ultralytics ONNX export before saving `model.onnx`. This removes graph input or output names, such as `output0`, that are duplicated in the ONNX `value_info` metadata and rejected by AI Hub.
|
|
||||||
|
|
||||||
For a model already downloaded by a failed upload attempt, sanitize the extracted ONNX file and retry using the local model. Replace the job name in both paths:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run --with onnx python examples/meter-detection/source/sanitize_onnx.py \
|
|
||||||
build/qai-hub/meter-detection/qc-cli-YYYYMMDD-HHMMSS/source/extracted/model.onnx \
|
|
||||||
--output build/qai-hub/meter-detection/model.aihub.onnx
|
|
||||||
|
|
||||||
qc-cli ai-hub upload \
|
|
||||||
examples/meter-detection/data/aihub_calibration \
|
|
||||||
examples/meter-detection/data/inputs.npz \
|
|
||||||
--onnx-path build/qai-hub/meter-detection/model.aihub.onnx
|
|
||||||
```
|
|
||||||
|
|
||||||
Download the compiled artifact after the workflow completes:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
qc-cli ai-hub download --output build/qai-hub/meter-detection/model.tflite
|
|
||||||
```
|
|
||||||
|
|
||||||
## Training Hyperparameters
|
|
||||||
|
|
||||||
Values under `sagemaker.training.hyperparameters` are passed to `source/train.py` as command-line arguments.
|
|
||||||
|
|
||||||
| Name | Type | Default | Description |
|
|
||||||
|---|---:|---:|---|
|
|
||||||
| `model` | string | `yolo26n.pt` | Ultralytics model weights or model YAML. |
|
|
||||||
| `epochs` | int | `25` | Number of training epochs. |
|
|
||||||
| `imgsz` | int | `640` | Square training image size. |
|
|
||||||
| `batch` | int | `16` | Images per training batch. |
|
|
||||||
| `workers` | int | `2` | DataLoader worker count. |
|
|
||||||
| `patience` | int | `20` | Early stopping patience. |
|
|
||||||
| `device` | string | auto | Optional Ultralytics device value such as `0` or `cpu`. |
|
|
||||||
| `data-yaml` | string | auto | Optional path to `data.yaml`; normally discovered from the uploaded dataset root. |
|
|
||||||
| `dataset-dir` | string | `SM_CHANNEL_TRAIN` | Uploaded dataset root mounted by SageMaker. |
|
|
||||||
|
|
||||||
Do not set `dataset-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""Prepare Qualcomm AI Hub calibration and validation inputs for the meter detector."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-dir",
|
|
||||||
type=Path,
|
|
||||||
default=Path("examples/meter-detection/data/electric-meter-detection"),
|
|
||||||
help="Root of the extracted Roboflow dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--calibration-dir",
|
|
||||||
type=Path,
|
|
||||||
default=Path("examples/meter-detection/data/aihub_calibration"),
|
|
||||||
help="Directory where .npy calibration samples will be written.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-file",
|
|
||||||
type=Path,
|
|
||||||
default=Path("examples/meter-detection/data/inputs.npz"),
|
|
||||||
help="Validation .npz input file for qc-cli ai-hub validate.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--input-name", default="images", help="ONNX input name.")
|
|
||||||
parser.add_argument("--image-size", type=int, default=640, help="Square image size used for ONNX export.")
|
|
||||||
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
|
|
||||||
"""Apply Ultralytics-style letterboxing and produce an NCHW float32 tensor."""
|
|
||||||
with Image.open(path) as source:
|
|
||||||
image = source.convert("RGB")
|
|
||||||
|
|
||||||
scale = min(image_size / image.width, image_size / image.height)
|
|
||||||
resized_width = round(image.width * scale)
|
|
||||||
resized_height = round(image.height * scale)
|
|
||||||
image = image.resize((resized_width, resized_height), Image.Resampling.BILINEAR)
|
|
||||||
|
|
||||||
canvas = Image.new("RGB", (image_size, image_size), (114, 114, 114))
|
|
||||||
left = round((image_size - resized_width) / 2 - 0.1)
|
|
||||||
top = round((image_size - resized_height) / 2 - 0.1)
|
|
||||||
canvas.paste(image, (left, top))
|
|
||||||
|
|
||||||
array = np.asarray(canvas, dtype=np.float32) / 255.0
|
|
||||||
return np.transpose(array, (2, 0, 1))[None, ...].astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
args = parse_args()
|
|
||||||
if args.image_size < 1:
|
|
||||||
raise SystemExit("--image-size must be at least 1")
|
|
||||||
if args.samples < 1:
|
|
||||||
raise SystemExit("--samples must be at least 1")
|
|
||||||
|
|
||||||
images = sorted(
|
|
||||||
path
|
|
||||||
for path in args.dataset_dir.rglob("*")
|
|
||||||
if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS and path.parent.name == "images"
|
|
||||||
)
|
|
||||||
if not images:
|
|
||||||
raise SystemExit(f"No images found under {args.dataset_dir}")
|
|
||||||
|
|
||||||
args.calibration_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
args.input_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
for stale_sample in args.calibration_dir.glob("sample_*.npy"):
|
|
||||||
stale_sample.unlink()
|
|
||||||
|
|
||||||
prepared: list[np.ndarray] = []
|
|
||||||
for index, image_path in enumerate(images[: args.samples]):
|
|
||||||
sample = preprocess_image(image_path, args.image_size)
|
|
||||||
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
|
|
||||||
prepared.append(sample)
|
|
||||||
|
|
||||||
np.savez(args.input_file, **{args.input_name: prepared[0]}) # pyright: ignore[reportArgumentType]
|
|
||||||
print(f"Wrote {len(prepared)} calibration samples to {args.calibration_dir}")
|
|
||||||
print(f"Wrote validation input to {args.input_file}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
ultralytics>=8.3.0
|
|
||||||
pyyaml>=6.0.3
|
|
||||||
onnx>=1.16.0
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""Remove ONNX value_info entries that duplicate graph inputs or outputs."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import onnx # type: ignore[reportMissingImports]
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_onnx(path: Path, output_path: Path | None = None) -> Path:
|
|
||||||
model = onnx.load(path)
|
|
||||||
io_names = {value.name for value in (*model.graph.input, *model.graph.output)}
|
|
||||||
retained_value_info = [value for value in model.graph.value_info if value.name not in io_names]
|
|
||||||
|
|
||||||
destination = output_path or path
|
|
||||||
if len(retained_value_info) != len(model.graph.value_info):
|
|
||||||
del model.graph.value_info[:]
|
|
||||||
model.graph.value_info.extend(retained_value_info)
|
|
||||||
|
|
||||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
onnx.save(model, destination)
|
|
||||||
return destination
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
|
||||||
parser.add_argument("onnx_path", type=Path)
|
|
||||||
parser.add_argument("--output", type=Path)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
written = sanitize_onnx(args.onnx_path, args.output)
|
|
||||||
print(f"Saved sanitized ONNX model to {written}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""SageMaker entry point for YOLO electric meter detection training."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from sanitize_onnx import sanitize_onnx
|
|
||||||
from training_metrics import write_training_metrics
|
|
||||||
from ultralytics import YOLO # type: ignore[reportMissingImports]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--model", default="yolo26n.pt")
|
|
||||||
parser.add_argument("--epochs", type=int, default=25)
|
|
||||||
parser.add_argument("--imgsz", type=int, default=640)
|
|
||||||
parser.add_argument("--batch", type=int, default=16)
|
|
||||||
parser.add_argument("--workers", type=int, default=2)
|
|
||||||
parser.add_argument("--patience", type=int, default=20)
|
|
||||||
parser.add_argument("--device", default=None)
|
|
||||||
parser.add_argument("--data-yaml", default=None)
|
|
||||||
parser.add_argument("--dataset-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
|
||||||
parser.add_argument("--train-dir", dest="dataset_dir", help=argparse.SUPPRESS)
|
|
||||||
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def find_data_yaml(dataset_dir: Path, explicit_path: str | None) -> Path:
|
|
||||||
if explicit_path:
|
|
||||||
data_yaml = Path(explicit_path)
|
|
||||||
if data_yaml.is_file():
|
|
||||||
return data_yaml
|
|
||||||
raise FileNotFoundError(f"Configured data.yaml does not exist: {data_yaml}")
|
|
||||||
|
|
||||||
matches = sorted(dataset_dir.rglob("data.yaml"))
|
|
||||||
if not matches:
|
|
||||||
raise FileNotFoundError(f"Could not find data.yaml under {dataset_dir}")
|
|
||||||
if len(matches) > 1:
|
|
||||||
print(f"Found multiple data.yaml files; using {matches[0]}")
|
|
||||||
return matches[0]
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_data_yaml(data_yaml: Path) -> Path:
|
|
||||||
"""Write a SageMaker-local data file rooted at the uploaded dataset."""
|
|
||||||
dataset_root = data_yaml.parent
|
|
||||||
data = yaml.safe_load(data_yaml.read_text(encoding="utf-8"))
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
raise ValueError(f"Expected a mapping in {data_yaml}")
|
|
||||||
|
|
||||||
normalized = dict(data)
|
|
||||||
normalized["path"] = str(dataset_root)
|
|
||||||
if "val" not in normalized and "valid" in normalized:
|
|
||||||
normalized["val"] = normalized.pop("valid")
|
|
||||||
|
|
||||||
prepared_path = dataset_root / "data.sagemaker.yaml"
|
|
||||||
prepared_path.write_text(yaml.safe_dump(normalized, sort_keys=False), encoding="utf-8")
|
|
||||||
print(f"Prepared dataset config: {prepared_path}")
|
|
||||||
return prepared_path
|
|
||||||
|
|
||||||
|
|
||||||
def copy_if_exists(source: Path, destination: Path) -> None:
|
|
||||||
if source.exists():
|
|
||||||
shutil.copy2(source, destination)
|
|
||||||
print(f"Saved {destination}")
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
args = parse_args()
|
|
||||||
dataset_dir = Path(args.dataset_dir)
|
|
||||||
model_dir = Path(args.model_dir)
|
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
data_yaml = prepare_data_yaml(find_data_yaml(dataset_dir, args.data_yaml))
|
|
||||||
model = YOLO(args.model)
|
|
||||||
|
|
||||||
train_kwargs: dict[str, Any] = {
|
|
||||||
"data": str(data_yaml),
|
|
||||||
"epochs": args.epochs,
|
|
||||||
"imgsz": args.imgsz,
|
|
||||||
"batch": args.batch,
|
|
||||||
"workers": args.workers,
|
|
||||||
"patience": args.patience,
|
|
||||||
"project": str(model_dir / "runs"),
|
|
||||||
"name": "train",
|
|
||||||
"exist_ok": True,
|
|
||||||
}
|
|
||||||
if args.device:
|
|
||||||
train_kwargs["device"] = args.device
|
|
||||||
|
|
||||||
results = model.train(**train_kwargs)
|
|
||||||
save_dir = Path(results.save_dir)
|
|
||||||
best_pt = save_dir / "weights" / "best.pt"
|
|
||||||
last_pt = save_dir / "weights" / "last.pt"
|
|
||||||
trained_weights = best_pt if best_pt.exists() else last_pt
|
|
||||||
if not trained_weights.exists():
|
|
||||||
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
|
|
||||||
|
|
||||||
write_training_metrics(save_dir / "results.csv", model_dir / "training_metrics.json")
|
|
||||||
copy_if_exists(trained_weights, model_dir / "best.pt")
|
|
||||||
trained_model = YOLO(str(trained_weights))
|
|
||||||
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
|
|
||||||
saved_onnx_path = sanitize_onnx(onnx_path, model_dir / "model.onnx")
|
|
||||||
print(f"Saved {saved_onnx_path}")
|
|
||||||
|
|
||||||
metrics = {
|
|
||||||
"model": args.model,
|
|
||||||
"epochs": args.epochs,
|
|
||||||
"imgsz": args.imgsz,
|
|
||||||
"batch": args.batch,
|
|
||||||
"workers": args.workers,
|
|
||||||
"patience": args.patience,
|
|
||||||
"data_yaml": str(data_yaml),
|
|
||||||
"weights": str(trained_weights),
|
|
||||||
"onnx": str(saved_onnx_path),
|
|
||||||
}
|
|
||||||
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
|
||||||
print(f"Saved model artifacts to {model_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
import csv
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
METRIC_NAMES = {
|
|
||||||
"metrics/precision(B)": "val.precision",
|
|
||||||
"metrics/recall(B)": "val.recall",
|
|
||||||
"metrics/mAP50(B)": "val.map50",
|
|
||||||
"metrics/mAP50-95(B)": "val.map50_95",
|
|
||||||
"train/box_loss": "train.box_loss",
|
|
||||||
"train/cls_loss": "train.cls_loss",
|
|
||||||
"train/dfl_loss": "train.dfl_loss",
|
|
||||||
"val/box_loss": "val.box_loss",
|
|
||||||
"val/cls_loss": "val.cls_loss",
|
|
||||||
"val/dfl_loss": "val.dfl_loss",
|
|
||||||
"time": "train.elapsed_seconds",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def write_training_metrics(results_csv: Path, destination: Path) -> None:
|
|
||||||
steps = _read_metric_steps(results_csv)
|
|
||||||
summary = _build_summary(steps)
|
|
||||||
payload = {
|
|
||||||
"schema_version": 1,
|
|
||||||
"steps": steps,
|
|
||||||
"summary": summary,
|
|
||||||
}
|
|
||||||
destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
|
||||||
print(f"Saved {destination}")
|
|
||||||
|
|
||||||
|
|
||||||
def _read_metric_steps(results_csv: Path) -> list[dict[str, Any]]:
|
|
||||||
if not results_csv.is_file():
|
|
||||||
raise FileNotFoundError(f"Could not find Ultralytics metrics history: {results_csv}")
|
|
||||||
|
|
||||||
steps: list[dict[str, Any]] = []
|
|
||||||
with results_csv.open(newline="", encoding="utf-8") as csv_file:
|
|
||||||
for row_index, raw_row in enumerate(csv.DictReader(csv_file)):
|
|
||||||
row = {str(key).strip(): value for key, value in raw_row.items()}
|
|
||||||
raw_epoch = row.pop("epoch", row_index)
|
|
||||||
step = int(float(raw_epoch))
|
|
||||||
metrics: dict[str, float] = {}
|
|
||||||
for source_name, raw_value in row.items():
|
|
||||||
if raw_value is None or not raw_value.strip():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
value = float(raw_value)
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
if math.isfinite(value):
|
|
||||||
metrics[METRIC_NAMES.get(source_name, _normalize_metric_name(source_name))] = value
|
|
||||||
steps.append({"step": step, "metrics": metrics})
|
|
||||||
return steps
|
|
||||||
|
|
||||||
|
|
||||||
def _build_summary(steps: list[dict[str, Any]]) -> dict[str, float]:
|
|
||||||
if not steps:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
summary: dict[str, float] = {}
|
|
||||||
final_step = steps[-1]
|
|
||||||
summary["summary.final_epoch"] = float(final_step["step"])
|
|
||||||
for name, value in final_step["metrics"].items():
|
|
||||||
summary[f"summary.final.{name}"] = value
|
|
||||||
|
|
||||||
scored_steps = [step for step in steps if "val.map50_95" in step["metrics"]]
|
|
||||||
if scored_steps:
|
|
||||||
best_step = max(scored_steps, key=lambda step: step["metrics"]["val.map50_95"])
|
|
||||||
summary["summary.best_epoch"] = float(best_step["step"])
|
|
||||||
summary["summary.best_val.map50_95"] = best_step["metrics"]["val.map50_95"]
|
|
||||||
if "val.map50" in best_step["metrics"]:
|
|
||||||
summary["summary.best_val.map50"] = best_step["metrics"]["val.map50"]
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_metric_name(name: str) -> str:
|
|
||||||
normalized = name.replace("/", ".")
|
|
||||||
normalized = re.sub(r"[^A-Za-z0-9_.-]+", "_", normalized)
|
|
||||||
return normalized.strip("._") or "unnamed"
|
|
||||||
89
examples/training/README.md
Normal file
89
examples/training/README.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# SageMaker Training Example
|
||||||
|
|
||||||
|
This example downloads a small image-classification dataset, uploads it through `qc-cli`, and submits a live SageMaker training job.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- AWS credentials configured for the profile in `config.yaml`
|
||||||
|
- Infrastructure already deployed with `qc-cli infra setup`
|
||||||
|
- `config.yaml` updated with:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
s3:
|
||||||
|
bucket: your-bucket-name
|
||||||
|
|
||||||
|
sagemaker:
|
||||||
|
training:
|
||||||
|
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
|
||||||
|
instance_type: ml.m4.xlarge
|
||||||
|
instance_count: 1
|
||||||
|
source_dir: examples/training/source
|
||||||
|
entry_point: train.py
|
||||||
|
hyperparameters:
|
||||||
|
epochs: 1
|
||||||
|
batch-size: 32
|
||||||
|
learning-rate: 0.001
|
||||||
|
image-size: 160
|
||||||
|
validation-split: 0.2
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Hyperparameters
|
||||||
|
|
||||||
|
Values under `sagemaker.training.hyperparameters` are passed to the training entry point as command-line arguments. For this example, they map to arguments defined in [source/train.py](source/train.py).
|
||||||
|
|
||||||
|
Supported by this example:
|
||||||
|
|
||||||
|
| Name | Type | Default | Description |
|
||||||
|
|---|---:|---:|---|
|
||||||
|
| `epochs` | int | `1` | Number of training epochs. |
|
||||||
|
| `batch-size` | int | `32` | Images per training batch. |
|
||||||
|
| `learning-rate` | float | `0.001` | Adam optimizer learning rate. |
|
||||||
|
| `image-size` | int | `160` | Resize images to square `image-size x image-size`. |
|
||||||
|
| `validation-split` | float | `0.2` | Fraction of data used for validation. |
|
||||||
|
| `max-samples` | int | `0` | Optional cap for smoke tests; `0` means use all images. |
|
||||||
|
| `seed` | int | `13` | Random seed for reproducible splitting. |
|
||||||
|
| `num-workers` | int | `2` | DataLoader worker count. |
|
||||||
|
|
||||||
|
Do not set `train-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.
|
||||||
|
|
||||||
|
## 1. Download The Dataset
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/download_flower_photos.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates:
|
||||||
|
|
||||||
|
```text
|
||||||
|
examples/training/data/flower_photos_sagemaker/
|
||||||
|
daisy/
|
||||||
|
dandelion/
|
||||||
|
roses/
|
||||||
|
sunflowers/
|
||||||
|
tulips/
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. Run Training
|
||||||
|
|
||||||
|
Run the training script and wait until it finishes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh --config config.yaml --wait
|
||||||
|
```
|
||||||
|
|
||||||
|
Use a dataset that is already uploaded to `s3.data_prefix`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash examples/training/run_training.sh \
|
||||||
|
--config config.yaml \
|
||||||
|
--skip-upload \
|
||||||
|
--wait
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The default dataset path is `examples/training/data/flower_photos_sagemaker`.
|
||||||
|
- Uploaded data uses the `s3.bucket` and `s3.data_prefix` values from `config.yaml`.
|
||||||
|
- Training artifacts are written under `s3://<bucket>/<model_prefix>/`.
|
||||||
|
- The SageMaker `model.tar.gz` contains `model.onnx`, `model.pt`, `class_to_idx.json`, and `metrics.json`.
|
||||||
|
- SageMaker packages `examples/training/source`, installs `requirements.txt`, and runs `train.py`.
|
||||||
40
examples/training/download_flower_photos.sh
Executable file
40
examples/training/download_flower_photos.sh
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
DATASET_URL="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
|
||||||
|
DEST_DIR="${1:-examples/training/data}"
|
||||||
|
ARCHIVE_PATH="${DEST_DIR}/flower_photos.tgz"
|
||||||
|
RAW_DATASET_DIR="${DEST_DIR}/flower_photos"
|
||||||
|
DATASET_DIR="${DEST_DIR}/flower_photos_sagemaker"
|
||||||
|
CLASS_NAMES=("daisy" "dandelion" "roses" "sunflowers" "tulips")
|
||||||
|
|
||||||
|
mkdir -p "${DEST_DIR}"
|
||||||
|
|
||||||
|
if [[ -d "${DATASET_DIR}" ]]; then
|
||||||
|
echo "Dataset already exists: ${DATASET_DIR}"
|
||||||
|
echo "Use this path with run_training.py:"
|
||||||
|
echo " ${DATASET_DIR}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Downloading TensorFlow flower_photos dataset..."
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
curl -L "${DATASET_URL}" -o "${ARCHIVE_PATH}"
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
wget -O "${ARCHIVE_PATH}" "${DATASET_URL}"
|
||||||
|
else
|
||||||
|
echo "Either curl or wget is required." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Extracting dataset..."
|
||||||
|
tar -xzf "${ARCHIVE_PATH}" -C "${DEST_DIR}"
|
||||||
|
|
||||||
|
echo "Preparing SageMaker directory layout..."
|
||||||
|
mkdir -p "${DATASET_DIR}"
|
||||||
|
for class_name in "${CLASS_NAMES[@]}"; do
|
||||||
|
cp -R "${RAW_DATASET_DIR}/${class_name}" "${DATASET_DIR}/${class_name}"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Dataset ready: ${DATASET_DIR}"
|
||||||
|
find "${DATASET_DIR}" -mindepth 1 -maxdepth 1 -type d -print | sort
|
||||||
112
examples/training/run_training.sh
Executable file
112
examples/training/run_training.sh
Executable file
@@ -0,0 +1,112 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
CONFIG_PATH="config.yaml"
|
||||||
|
DATASET_DIR="examples/training/data/flower_photos_sagemaker"
|
||||||
|
WAIT=false
|
||||||
|
SKIP_UPLOAD=false
|
||||||
|
POLL_SECONDS=60
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat <<EOF
|
||||||
|
Usage: $0 [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--config PATH Path to qc-cli config file. Default: config.yaml
|
||||||
|
--dataset-dir PATH Dataset directory to upload. Default: ${DATASET_DIR}
|
||||||
|
--skip-upload Train against data already uploaded to s3.data_prefix.
|
||||||
|
--wait Poll until training completes.
|
||||||
|
-h, --help Show this help.
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case "$1" in
|
||||||
|
--config)
|
||||||
|
CONFIG_PATH="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--dataset-dir)
|
||||||
|
DATASET_DIR="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--skip-upload)
|
||||||
|
SKIP_UPLOAD=true
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
--wait)
|
||||||
|
WAIT=true
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1" >&2
|
||||||
|
usage >&2
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ! -f "${CONFIG_PATH}" ]]; then
|
||||||
|
echo "Config not found: ${CONFIG_PATH}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${SKIP_UPLOAD}" == false && ! -d "${DATASET_DIR}" ]]; then
|
||||||
|
echo "Dataset not found: ${DATASET_DIR}" >&2
|
||||||
|
echo "Run: bash examples/training/download_flower_photos.sh" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
run() {
|
||||||
|
echo "+ $*"
|
||||||
|
"$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
run uv run qc-cli infra status --config "${CONFIG_PATH}"
|
||||||
|
|
||||||
|
if [[ "${SKIP_UPLOAD}" == false ]]; then
|
||||||
|
run uv run qc-cli upload "${DATASET_DIR}" --config "${CONFIG_PATH}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
TRAIN_OUTPUT_FILE="$(mktemp)"
|
||||||
|
trap 'rm -f "${TRAIN_OUTPUT_FILE}"' EXIT
|
||||||
|
run uv run qc-cli train start --config "${CONFIG_PATH}" | tee "${TRAIN_OUTPUT_FILE}"
|
||||||
|
|
||||||
|
JOB_NAME="$(grep -Eo 'qc-cli-[0-9]{8}-[0-9]{6}' "${TRAIN_OUTPUT_FILE}" | tail -n 1)"
|
||||||
|
if [[ -z "${JOB_NAME}" ]]; then
|
||||||
|
echo "Could not find training job name in qc-cli output." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Submitted SageMaker training job: ${JOB_NAME}"
|
||||||
|
|
||||||
|
if [[ "${WAIT}" == false ]]; then
|
||||||
|
run uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
STATUS_OUTPUT="$(uv run qc-cli train status "${JOB_NAME}" --config "${CONFIG_PATH}")"
|
||||||
|
echo "${STATUS_OUTPUT}"
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Completed'; then
|
||||||
|
echo "Training completed successfully."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Failed'; then
|
||||||
|
echo "Training failed." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if printf '%s\n' "${STATUS_OUTPUT}" | grep -q 'Status:.*Stopped'; then
|
||||||
|
echo "Training stopped." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep "${POLL_SECONDS}"
|
||||||
|
done
|
||||||
1
examples/training/source/requirements.txt
Normal file
1
examples/training/source/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
onnx==1.21.0
|
||||||
188
examples/training/source/train.py
Normal file
188
examples/training/source/train.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""SageMaker entry point for CPU image-classification training."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader, Subset, random_split
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
|
||||||
|
|
||||||
|
class SmallImageClassifier(nn.Module):
|
||||||
|
def __init__(self, class_count: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Linear(64, class_count)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.features(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
return self.classifier(x)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--epochs", type=int, default=1)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=32)
|
||||||
|
parser.add_argument("--learning-rate", type=float, default=0.001)
|
||||||
|
parser.add_argument("--image-size", type=int, default=160)
|
||||||
|
parser.add_argument("--validation-split", type=float, default=0.2)
|
||||||
|
parser.add_argument("--max-samples", type=int, default=0)
|
||||||
|
parser.add_argument("--seed", type=int, default=13)
|
||||||
|
parser.add_argument("--num-workers", type=int, default=2)
|
||||||
|
parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
|
||||||
|
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_datasets(args: argparse.Namespace) -> tuple[Subset, Subset, dict[str, int]]:
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize((args.image_size, args.image_size)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset = datasets.ImageFolder(args.train_dir, transform=transform)
|
||||||
|
if len(dataset.classes) < 2:
|
||||||
|
raise ValueError(f"Expected at least two classes in {args.train_dir}. Found: {dataset.classes}")
|
||||||
|
|
||||||
|
if args.max_samples > 0 and args.max_samples < len(dataset):
|
||||||
|
indices = list(range(len(dataset)))
|
||||||
|
random.Random(args.seed).shuffle(indices)
|
||||||
|
dataset = Subset(dataset, indices[: args.max_samples])
|
||||||
|
|
||||||
|
validation_size = max(1, int(len(dataset) * args.validation_split))
|
||||||
|
train_size = len(dataset) - validation_size
|
||||||
|
if train_size < 1:
|
||||||
|
raise ValueError("Not enough images to create a train/validation split.")
|
||||||
|
|
||||||
|
generator = torch.Generator().manual_seed(args.seed)
|
||||||
|
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=generator)
|
||||||
|
return train_dataset, validation_dataset, getattr(dataset, "dataset", dataset).class_to_idx
|
||||||
|
|
||||||
|
|
||||||
|
def run_epoch(
|
||||||
|
model: nn.Module,
|
||||||
|
data_loader: DataLoader,
|
||||||
|
criterion: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer | None,
|
||||||
|
device: torch.device,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
training = optimizer is not None
|
||||||
|
model.train(training)
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
total_correct = 0
|
||||||
|
total_examples = 0
|
||||||
|
|
||||||
|
for images, labels in data_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(training):
|
||||||
|
logits = model(images)
|
||||||
|
loss = criterion(logits, labels)
|
||||||
|
|
||||||
|
if training:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item() * images.size(0)
|
||||||
|
total_correct += (logits.argmax(dim=1) == labels).sum().item()
|
||||||
|
total_examples += images.size(0)
|
||||||
|
|
||||||
|
return total_loss / total_examples, total_correct / total_examples
|
||||||
|
|
||||||
|
|
||||||
|
def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None:
|
||||||
|
model.eval()
|
||||||
|
dummy_input = torch.randn(1, 3, image_size, image_size)
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
dummy_input,
|
||||||
|
model_dir / "model.onnx",
|
||||||
|
export_params=True,
|
||||||
|
opset_version=17,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=["input"],
|
||||||
|
output_names=["logits"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
train_dataset, validation_dataset, class_to_idx = build_datasets(args)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
)
|
||||||
|
validation_loader = DataLoader(
|
||||||
|
validation_dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = SmallImageClassifier(class_count=len(class_to_idx)).to(device)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||||
|
|
||||||
|
print(f"Training on {device}. Classes: {sorted(class_to_idx)}")
|
||||||
|
metrics = []
|
||||||
|
for epoch in range(1, args.epochs + 1):
|
||||||
|
train_loss, train_accuracy = run_epoch(model, train_loader, criterion, optimizer, device)
|
||||||
|
validation_loss, validation_accuracy = run_epoch(model, validation_loader, criterion, None, device)
|
||||||
|
epoch_metrics = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"train_loss": train_loss,
|
||||||
|
"train_accuracy": train_accuracy,
|
||||||
|
"validation_loss": validation_loss,
|
||||||
|
"validation_accuracy": validation_accuracy,
|
||||||
|
}
|
||||||
|
metrics.append(epoch_metrics)
|
||||||
|
print(json.dumps(epoch_metrics, sort_keys=True))
|
||||||
|
|
||||||
|
model_dir = Path(args.model_dir)
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"model_state_dict": model.cpu().state_dict(),
|
||||||
|
"class_to_idx": class_to_idx,
|
||||||
|
"image_size": args.image_size,
|
||||||
|
},
|
||||||
|
model_dir / "model.pt",
|
||||||
|
)
|
||||||
|
export_onnx(model, model_dir, args.image_size)
|
||||||
|
(model_dir / "class_to_idx.json").write_text(json.dumps(class_to_idx, indent=2), encoding="utf-8")
|
||||||
|
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
||||||
|
print(f"Saved model artifacts to {model_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
|
|||||||
[project]
|
[project]
|
||||||
name = "qc-cli"
|
name = "qc-cli"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "CLI for training and deploying models for Qualcomm AI Hub"
|
description = "CLI for SageMaker ONNX training and Qualcomm AI Hub optimization"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aws-cdk-lib>=2.180.0",
|
"aws-cdk-lib>=2.180.0",
|
||||||
@@ -29,6 +29,8 @@ packages = ["src"]
|
|||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"boto3-stubs[iam,s3,sagemaker]",
|
"boto3-stubs[iam,s3,sagemaker]",
|
||||||
|
"pytest>=8.0",
|
||||||
|
"pytest-mock>=3.12",
|
||||||
"pyright>=1.1.409",
|
"pyright>=1.1.409",
|
||||||
"types-PyYAML",
|
"types-PyYAML",
|
||||||
"ruff>=0.4",
|
"ruff>=0.4",
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
import os
|
|
||||||
from collections.abc import Generator
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
@@ -37,38 +34,3 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
|
|||||||
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||||
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
||||||
return str(response["AuthorizedUrl"])
|
return str(response["AuthorizedUrl"])
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
|
|
||||||
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
|
|
||||||
if credentials is None:
|
|
||||||
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
|
|
||||||
|
|
||||||
frozen_credentials = credentials.get_frozen_credentials()
|
|
||||||
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
|
|
||||||
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
|
|
||||||
|
|
||||||
env_updates = {
|
|
||||||
"AWS_PROFILE": profile,
|
|
||||||
"AWS_DEFAULT_REGION": region,
|
|
||||||
"AWS_REGION": region,
|
|
||||||
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
|
|
||||||
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
|
|
||||||
}
|
|
||||||
if frozen_credentials.token:
|
|
||||||
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
|
|
||||||
|
|
||||||
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
|
|
||||||
previous_env = {key: os.environ.get(key) for key in restore_keys}
|
|
||||||
try:
|
|
||||||
os.environ.update(env_updates)
|
|
||||||
if not frozen_credentials.token:
|
|
||||||
os.environ.pop("AWS_SESSION_TOKEN", None)
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
for key, value in previous_env.items():
|
|
||||||
if value is None:
|
|
||||||
os.environ.pop(key, None)
|
|
||||||
else:
|
|
||||||
os.environ[key] = value
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Cloud provider adapters."""
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
from contextlib import AbstractContextManager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Protocol
|
|
||||||
|
|
||||||
from src.aws import mlflow as aws_mlflow
|
|
||||||
from src.config import Config
|
|
||||||
|
|
||||||
|
|
||||||
class MlflowTrackingBackend(Protocol):
|
|
||||||
@property
|
|
||||||
def provider_name(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def profile(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def region(self) -> str: ...
|
|
||||||
|
|
||||||
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
|
|
||||||
|
|
||||||
def auth_env(self) -> AbstractContextManager[None]: ...
|
|
||||||
|
|
||||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
|
|
||||||
|
|
||||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
|
|
||||||
|
|
||||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
|
|
||||||
|
|
||||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AwsMlflowTrackingBackend:
|
|
||||||
profile: str
|
|
||||||
region: str
|
|
||||||
provider_name: str = "aws"
|
|
||||||
|
|
||||||
def get_tracking_uri(self, tracking_server_name: str) -> str:
|
|
||||||
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
|
|
||||||
|
|
||||||
def auth_env(self) -> AbstractContextManager[None]:
|
|
||||||
return aws_mlflow.tracking_auth_env(self.profile, self.region)
|
|
||||||
|
|
||||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"provider.name": self.provider_name,
|
|
||||||
"provider.region": region,
|
|
||||||
"provider.profile": profile,
|
|
||||||
"sagemaker.role_arn": role_arn,
|
|
||||||
"sagemaker.job_name": training_job.job_name,
|
|
||||||
"sagemaker.training_image": training_job.image_uri,
|
|
||||||
"sagemaker.instance_type": training_job.instance_type,
|
|
||||||
"sagemaker.instance_count": training_job.instance_count,
|
|
||||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
|
||||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
|
||||||
"sagemaker.entry_point": training_job.entry_point,
|
|
||||||
"sagemaker.source_dir": training_job.source_dir,
|
|
||||||
}
|
|
||||||
|
|
||||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
|
|
||||||
return {"sagemaker.job_name": training_job.job_name}
|
|
||||||
|
|
||||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"sagemaker.training_status": training_job_status.status,
|
|
||||||
"sagemaker.created_at": training_job_status.created,
|
|
||||||
"sagemaker.modified_at": training_job_status.modified,
|
|
||||||
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
|
||||||
"sagemaker.failure_reason": training_job_status.failure_reason,
|
|
||||||
}
|
|
||||||
|
|
||||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
|
|
||||||
return {"sagemaker.job_name": training_job_status.name}
|
|
||||||
|
|
||||||
|
|
||||||
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
|
|
||||||
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)
|
|
||||||
@@ -1,21 +1,18 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import qai_hub.hub as hub
|
|
||||||
import typer
|
import typer
|
||||||
from qai_hub.client import Device
|
|
||||||
|
|
||||||
from src import state as state_ops
|
from src import state as state_ops
|
||||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||||
from src.config import Config
|
from src.config import Config
|
||||||
from src.qualcomm import aihub_jobs
|
from src.qualcomm import aihub_jobs
|
||||||
from src.qualcomm.artifacts import ResolvedOnnx, resolve_onnx
|
from src.qualcomm.artifacts import resolve_onnx
|
||||||
|
|
||||||
app = typer.Typer(help="Optimize, quantize, compile, validate, profile, and download models with Qualcomm Workbench")
|
app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm AI Hub")
|
||||||
|
|
||||||
_RUNTIME_EXTENSIONS = {
|
_RUNTIME_EXTENSIONS = {
|
||||||
"tflite": "tflite",
|
"tflite": "tflite",
|
||||||
@@ -25,19 +22,12 @@ _RUNTIME_EXTENSIONS = {
|
|||||||
|
|
||||||
|
|
||||||
class UploadStep(StrEnum):
|
class UploadStep(StrEnum):
|
||||||
optimize = "optimize"
|
|
||||||
quantize = "quantize"
|
quantize = "quantize"
|
||||||
compile = "compile"
|
compile = "compile"
|
||||||
validate = "validate"
|
validate = "validate"
|
||||||
profile = "profile"
|
profile = "profile"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ResolvedModelSource:
|
|
||||||
model: str | Path
|
|
||||||
model_artifact: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
|
def _input_specs(cfg: Config) -> dict[str, tuple[tuple[int, ...], str]]:
|
||||||
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
|
specs = {name: (tuple(shape), dtype) for name, (shape, dtype) in cfg.aihub.input_specs.items()}
|
||||||
if not specs:
|
if not specs:
|
||||||
@@ -109,105 +99,24 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
|
|||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model_source(
|
|
||||||
cfg: Config,
|
|
||||||
config_path: str,
|
|
||||||
*,
|
|
||||||
model_id: str | None = None,
|
|
||||||
previous_model_id: str | None = None,
|
|
||||||
from_job: str | None = None,
|
|
||||||
model_s3_uri: str | None = None,
|
|
||||||
onnx_path: str | None = None,
|
|
||||||
) -> ResolvedModelSource:
|
|
||||||
if model_id:
|
|
||||||
return ResolvedModelSource(model_id)
|
|
||||||
|
|
||||||
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
|
|
||||||
if previous_model_id and not has_explicit_source:
|
|
||||||
return ResolvedModelSource(previous_model_id)
|
|
||||||
|
|
||||||
resolved = _resolve_onnx_source(
|
|
||||||
cfg,
|
|
||||||
config_path,
|
|
||||||
from_job=from_job,
|
|
||||||
model_s3_uri=model_s3_uri,
|
|
||||||
onnx_path=onnx_path,
|
|
||||||
)
|
|
||||||
return ResolvedModelSource(resolved.onnx_path, resolved.model_artifact)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_onnx_source(
|
|
||||||
cfg: Config,
|
|
||||||
config_path: str,
|
|
||||||
*,
|
|
||||||
from_job: str | None = None,
|
|
||||||
model_s3_uri: str | None = None,
|
|
||||||
onnx_path: str | None = None,
|
|
||||||
) -> ResolvedOnnx:
|
|
||||||
st = state_ops.store(config_path)
|
|
||||||
last_training_job = st.get_last_training_job()
|
|
||||||
saved_model_artifact = None
|
|
||||||
if not from_job and not model_s3_uri and not onnx_path and not last_training_job:
|
|
||||||
saved_model_artifact = st.get_last_model_artifact()
|
|
||||||
|
|
||||||
return resolve_onnx(
|
|
||||||
cfg=cfg,
|
|
||||||
output_dir=cfg.aihub.output_dir,
|
|
||||||
from_job=from_job,
|
|
||||||
model_s3_uri=model_s3_uri or saved_model_artifact,
|
|
||||||
onnx_path=onnx_path,
|
|
||||||
last_training_job=last_training_job,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _device_selector(device: Device) -> str:
|
|
||||||
parts: list[str] = []
|
|
||||||
if device.name:
|
|
||||||
parts.append(f"name={device.name!r}")
|
|
||||||
if device.os:
|
|
||||||
parts.append(f"os={device.os!r}")
|
|
||||||
if device.attributes:
|
|
||||||
parts.append(f"attributes={device.attributes!r}")
|
|
||||||
return ", ".join(parts) if parts else "empty selector"
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_device(cfg: Config) -> None:
|
|
||||||
device = cfg.aihub.device
|
|
||||||
try:
|
|
||||||
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
if matches:
|
|
||||||
return
|
|
||||||
|
|
||||||
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
|
|
||||||
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _quantize_step(
|
def _quantize_step(
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
config_path: str,
|
config_path: str,
|
||||||
calibration_path: Path,
|
calibration_path: Path,
|
||||||
*,
|
from_job: str | None,
|
||||||
model_id: str | None = None,
|
model_s3_uri: str | None,
|
||||||
from_job: str | None = None,
|
onnx_path: str | None,
|
||||||
model_s3_uri: str | None = None,
|
|
||||||
onnx_path: str | None = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
st = state_ops.store(config_path)
|
st = state_ops.store(config_path)
|
||||||
specs = _input_specs(cfg)
|
specs = _input_specs(cfg)
|
||||||
try:
|
try:
|
||||||
source = _resolve_model_source(
|
resolved = resolve_onnx(
|
||||||
cfg,
|
cfg=cfg,
|
||||||
config_path,
|
output_dir=cfg.aihub.output_dir,
|
||||||
model_id=model_id,
|
|
||||||
previous_model_id=st.get_last_optimized_model_id(),
|
|
||||||
from_job=from_job,
|
from_job=from_job,
|
||||||
model_s3_uri=model_s3_uri,
|
model_s3_uri=model_s3_uri or st.get_last_model_artifact(),
|
||||||
onnx_path=onnx_path,
|
onnx_path=onnx_path,
|
||||||
|
last_training_job=st.get_last_training_job(),
|
||||||
)
|
)
|
||||||
calibration_data = _load_calibration(calibration_path, specs)
|
calibration_data = _load_calibration(calibration_path, specs)
|
||||||
except (FileNotFoundError, ValueError) as e:
|
except (FileNotFoundError, ValueError) as e:
|
||||||
@@ -215,117 +124,72 @@ def _quantize_step(
|
|||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hub_model = (
|
|
||||||
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
|
||||||
if isinstance(source.model, Path)
|
|
||||||
else hub.get_model(source.model)
|
|
||||||
)
|
|
||||||
result = aihub_jobs.submit_quantize_job(
|
result = aihub_jobs.submit_quantize_job(
|
||||||
hub_model,
|
resolved.onnx_path,
|
||||||
calibration_data,
|
calibration_data,
|
||||||
cfg.aihub.quantize_options,
|
cfg.aihub.quantize_options,
|
||||||
job_name=_job_name(cfg, "quantize"),
|
job_name=_job_name(cfg, "quantize"),
|
||||||
|
model_name=cfg.aihub.model_name,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
CONSOLE.print(f"[red]AI Hub quantize failed: {e}[/red]")
|
CONSOLE.print(f"[red]AI Hub quantize failed: {e}[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
updates: dict[str, Any] = {
|
st.update(
|
||||||
"last_quantize_job_id": result["job_id"],
|
last_model_artifact=resolved.model_artifact,
|
||||||
"last_quantized_model_id": result["model_id"],
|
last_quantize_job_id=result["job_id"],
|
||||||
}
|
last_quantized_model_id=result["model_id"],
|
||||||
if source.model_artifact:
|
)
|
||||||
updates["last_model_artifact"] = source.model_artifact
|
|
||||||
st.update(**updates)
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
|
CONSOLE.print(f"[green]✓[/green] Quantize job: [bold]{result['job_id']}[/bold]")
|
||||||
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
|
CONSOLE.print(f"[green]✓[/green] Quantized model: [bold]{result['model_id']}[/bold]")
|
||||||
return str(result["model_id"])
|
return str(result["model_id"])
|
||||||
|
|
||||||
|
|
||||||
def _optimize_step(
|
|
||||||
cfg: Config,
|
|
||||||
config_path: str,
|
|
||||||
from_job: str | None,
|
|
||||||
model_s3_uri: str | None,
|
|
||||||
onnx_path: str | None,
|
|
||||||
) -> str:
|
|
||||||
st = state_ops.store(config_path)
|
|
||||||
_validate_device(cfg)
|
|
||||||
specs = _input_specs(cfg)
|
|
||||||
try:
|
|
||||||
source = _resolve_onnx_source(
|
|
||||||
cfg,
|
|
||||||
config_path,
|
|
||||||
from_job=from_job,
|
|
||||||
model_s3_uri=model_s3_uri,
|
|
||||||
onnx_path=onnx_path,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, ValueError) as e:
|
|
||||||
CONSOLE.print(f"[red]{e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
hub_model = hub.upload_model(str(source.onnx_path), name=cfg.aihub.model_name)
|
|
||||||
result = aihub_jobs.submit_compile_job(
|
|
||||||
model=hub_model,
|
|
||||||
device=cfg.aihub.device,
|
|
||||||
input_specs=specs,
|
|
||||||
target_runtime="onnx",
|
|
||||||
job_name=_job_name(cfg, "optimize"),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]AI Hub ONNX optimization failed: {e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
st.update(
|
|
||||||
last_model_artifact=source.model_artifact,
|
|
||||||
last_optimize_job_id=result["job_id"],
|
|
||||||
last_optimized_model_id=result["model_id"],
|
|
||||||
)
|
|
||||||
CONSOLE.print(f"[green]✓[/green] ONNX optimization job: [bold]{result['job_id']}[/bold]")
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Optimized ONNX model: [bold]{result['model_id']}[/bold]")
|
|
||||||
return str(result["model_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_step(
|
def _compile_step(
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
config_path: str,
|
config_path: str,
|
||||||
|
model_id: str | None,
|
||||||
|
from_job: str | None,
|
||||||
|
model_s3_uri: str | None,
|
||||||
|
onnx_path: str | None,
|
||||||
*,
|
*,
|
||||||
model_id: str | None = None,
|
prefer_quantized: bool,
|
||||||
from_job: str | None = None,
|
|
||||||
model_s3_uri: str | None = None,
|
|
||||||
onnx_path: str | None = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
st = state_ops.store(config_path)
|
st = state_ops.store(config_path)
|
||||||
_validate_device(cfg)
|
|
||||||
specs = _input_specs(cfg)
|
specs = _input_specs(cfg)
|
||||||
try:
|
|
||||||
source = _resolve_model_source(
|
model: Any
|
||||||
cfg,
|
model_artifact: str | None = None
|
||||||
config_path,
|
has_explicit_source = bool(from_job or model_s3_uri or onnx_path)
|
||||||
model_id=model_id,
|
if model_id:
|
||||||
previous_model_id=st.get_last_quantized_model_id(),
|
model = model_id
|
||||||
from_job=from_job,
|
elif prefer_quantized and not has_explicit_source and st.get_last_quantized_model_id():
|
||||||
model_s3_uri=model_s3_uri,
|
model = st.get_last_quantized_model_id()
|
||||||
onnx_path=onnx_path,
|
else:
|
||||||
)
|
try:
|
||||||
except (FileNotFoundError, ValueError) as e:
|
resolved = resolve_onnx(
|
||||||
CONSOLE.print(f"[red]{e}[/red]")
|
cfg=cfg,
|
||||||
raise typer.Exit(1)
|
output_dir=cfg.aihub.output_dir,
|
||||||
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
last_training_job=st.get_last_training_job(),
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
CONSOLE.print(f"[red]{e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
model = resolved.onnx_path
|
||||||
|
model_artifact = resolved.model_artifact
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hub_model = (
|
|
||||||
hub.upload_model(str(source.model), name=cfg.aihub.model_name)
|
|
||||||
if isinstance(source.model, Path)
|
|
||||||
else hub.get_model(source.model)
|
|
||||||
)
|
|
||||||
result = aihub_jobs.submit_compile_job(
|
result = aihub_jobs.submit_compile_job(
|
||||||
model=hub_model,
|
model=model,
|
||||||
device=cfg.aihub.device,
|
device_name=cfg.aihub.device,
|
||||||
input_specs=specs,
|
input_specs=specs,
|
||||||
target_runtime=cfg.aihub.target_runtime,
|
target_runtime=cfg.aihub.target_runtime,
|
||||||
options=cfg.aihub.compile_options,
|
options=cfg.aihub.compile_options,
|
||||||
job_name=_job_name(cfg, "compile"),
|
job_name=_job_name(cfg, "compile"),
|
||||||
|
model_name=cfg.aihub.model_name if isinstance(model, Path) else None,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
CONSOLE.print(f"[red]AI Hub compile failed: {e}[/red]")
|
CONSOLE.print(f"[red]AI Hub compile failed: {e}[/red]")
|
||||||
@@ -335,8 +199,8 @@ def _compile_step(
|
|||||||
"last_compile_job_id": result["job_id"],
|
"last_compile_job_id": result["job_id"],
|
||||||
"last_compiled_model_id": result["model_id"],
|
"last_compiled_model_id": result["model_id"],
|
||||||
}
|
}
|
||||||
if source.model_artifact:
|
if model_artifact:
|
||||||
updates["last_model_artifact"] = source.model_artifact
|
updates["last_model_artifact"] = model_artifact
|
||||||
st.update(**updates)
|
st.update(**updates)
|
||||||
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
|
CONSOLE.print(f"[green]✓[/green] Compile job: [bold]{result['job_id']}[/bold]")
|
||||||
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
|
CONSOLE.print(f"[green]✓[/green] Compiled model: [bold]{result['model_id']}[/bold]")
|
||||||
@@ -350,7 +214,6 @@ def _validate_step(
|
|||||||
model_id: str | None,
|
model_id: str | None,
|
||||||
input_name: str | None,
|
input_name: str | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
_validate_device(cfg)
|
|
||||||
specs = _input_specs(cfg)
|
specs = _input_specs(cfg)
|
||||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||||
try:
|
try:
|
||||||
@@ -362,9 +225,8 @@ def _validate_step(
|
|||||||
run = datetime.now().strftime("%Y%m%d-%H%M%S")
|
run = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
out_dir = Path(cfg.aihub.output_dir) / run / "validation"
|
out_dir = Path(cfg.aihub.output_dir) / run / "validation"
|
||||||
try:
|
try:
|
||||||
hub_model = hub.get_model(resolved_model_id)
|
|
||||||
result = aihub_jobs.submit_inference_job(
|
result = aihub_jobs.submit_inference_job(
|
||||||
hub_model,
|
resolved_model_id,
|
||||||
cfg.aihub.device,
|
cfg.aihub.device,
|
||||||
inputs,
|
inputs,
|
||||||
out_dir,
|
out_dir,
|
||||||
@@ -385,12 +247,10 @@ def _validate_step(
|
|||||||
|
|
||||||
|
|
||||||
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
||||||
_validate_device(cfg)
|
|
||||||
resolved_model_id = _model_id_or_state(config_path, model_id)
|
resolved_model_id = _model_id_or_state(config_path, model_id)
|
||||||
try:
|
try:
|
||||||
hub_model = hub.get_model(resolved_model_id)
|
|
||||||
result = aihub_jobs.submit_profile_job(
|
result = aihub_jobs.submit_profile_job(
|
||||||
hub_model,
|
resolved_model_id,
|
||||||
cfg.aihub.device,
|
cfg.aihub.device,
|
||||||
cfg.aihub.profile_options,
|
cfg.aihub.profile_options,
|
||||||
job_name=_job_name(cfg, "profile"),
|
job_name=_job_name(cfg, "profile"),
|
||||||
@@ -403,24 +263,9 @@ def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
|
|||||||
return str(result["job_id"])
|
return str(result["job_id"])
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def optimize(
|
|
||||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should optimize"),
|
|
||||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to optimize"),
|
|
||||||
onnx_path: str | None = typer.Option(
|
|
||||||
None, "--onnx-path", help="Local ONNX path or ONNX path inside extracted artifact"
|
|
||||||
),
|
|
||||||
config: str = CONFIG_OPT,
|
|
||||||
) -> None:
|
|
||||||
"""Optimize an external model into a Workbench-produced ONNX model."""
|
|
||||||
cfg = load_cfg(config)
|
|
||||||
_optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def quantize(
|
def quantize(
|
||||||
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||||
model_id: str | None = typer.Option(None, "--model-id", help="AI Hub optimized ONNX model ID"),
|
|
||||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should quantize"),
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should quantize"),
|
||||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to quantize"),
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to quantize"),
|
||||||
onnx_path: str | None = typer.Option(
|
onnx_path: str | None = typer.Option(
|
||||||
@@ -430,15 +275,7 @@ def quantize(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Quantize an ONNX model to INT8."""
|
"""Quantize an ONNX model to INT8."""
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
_quantize_step(
|
_quantize_step(cfg, config, calibration_path, from_job, model_s3_uri, onnx_path)
|
||||||
cfg,
|
|
||||||
config,
|
|
||||||
calibration_path,
|
|
||||||
model_id=model_id,
|
|
||||||
from_job=from_job,
|
|
||||||
model_s3_uri=model_s3_uri,
|
|
||||||
onnx_path=onnx_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@@ -453,14 +290,7 @@ def compile(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Compile a model for the configured Qualcomm AI Hub target."""
|
"""Compile a model for the configured Qualcomm AI Hub target."""
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
_compile_step(
|
_compile_step(cfg, config, model_id, from_job, model_s3_uri, onnx_path, prefer_quantized=True)
|
||||||
cfg,
|
|
||||||
config,
|
|
||||||
model_id=model_id,
|
|
||||||
from_job=from_job,
|
|
||||||
model_s3_uri=model_s3_uri,
|
|
||||||
onnx_path=onnx_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@@ -489,7 +319,7 @@ def profile(
|
|||||||
def upload(
|
def upload(
|
||||||
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
calibration_path: Path = typer.Argument(..., help="Calibration .npz file or directory of .npy samples"),
|
||||||
input_file: Path = typer.Argument(..., help="Validation .npz or .npy inputs to run on device"),
|
input_file: Path = typer.Argument(..., help="Validation .npz or .npy inputs to run on device"),
|
||||||
from_step: UploadStep = typer.Option(UploadStep.optimize, "--from-step", help="Resume from this Workbench step"),
|
from_step: UploadStep = typer.Option(UploadStep.quantize, "--from-step", help="Resume from this Workbench step"),
|
||||||
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should upload"),
|
from_job: str | None = typer.Option(None, "--from-job", help="Training job name whose model artifact should upload"),
|
||||||
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to upload"),
|
model_s3_uri: str | None = typer.Option(None, "--model-s3-uri", help="S3 URI of model.tar.gz to upload"),
|
||||||
onnx_path: str | None = typer.Option(
|
onnx_path: str | None = typer.Option(
|
||||||
@@ -498,48 +328,25 @@ def upload(
|
|||||||
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy validation files"),
|
input_name: str | None = typer.Option(None, "--input-name", help="Input name for .npy validation files"),
|
||||||
config: str = CONFIG_OPT,
|
config: str = CONFIG_OPT,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Optimize, quantize, optionally compile, validate, and profile a model."""
|
"""Run the four Workbench upload steps: quantize, compile, validate, and profile."""
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
steps = [UploadStep.optimize, UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
|
steps = [UploadStep.quantize, UploadStep.compile, UploadStep.validate, UploadStep.profile]
|
||||||
selected = steps[steps.index(from_step) :]
|
selected = steps[steps.index(from_step) :]
|
||||||
|
|
||||||
optimized_model_id: str | None = None
|
|
||||||
quantized_model_id: str | None = None
|
quantized_model_id: str | None = None
|
||||||
compiled_model_id: str | None = None
|
compiled_model_id: str | None = None
|
||||||
if UploadStep.optimize in selected:
|
|
||||||
optimized_model_id = _optimize_step(cfg, config, from_job, model_s3_uri, onnx_path)
|
|
||||||
if UploadStep.quantize in selected:
|
if UploadStep.quantize in selected:
|
||||||
if UploadStep.optimize not in selected:
|
quantized_model_id = _quantize_step(cfg, config, calibration_path, from_job, model_s3_uri, onnx_path)
|
||||||
optimized_model_id = state_ops.store(config).get_last_optimized_model_id()
|
if UploadStep.compile in selected:
|
||||||
if not optimized_model_id:
|
compiled_model_id = _compile_step(
|
||||||
CONSOLE.print(
|
|
||||||
"[red]No optimized ONNX model found. Resume from --from-step optimize or run "
|
|
||||||
"'qc-cli ai-hub optimize' first.[/red]"
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
quantized_model_id = _quantize_step(
|
|
||||||
cfg,
|
cfg,
|
||||||
config,
|
config,
|
||||||
calibration_path,
|
model_id=quantized_model_id,
|
||||||
model_id=optimized_model_id,
|
from_job=from_job,
|
||||||
|
model_s3_uri=model_s3_uri,
|
||||||
|
onnx_path=onnx_path,
|
||||||
|
prefer_quantized=True,
|
||||||
)
|
)
|
||||||
if UploadStep.compile in selected:
|
|
||||||
if cfg.aihub.target_runtime == "onnx":
|
|
||||||
compiled_model_id = quantized_model_id or state_ops.store(config).get_last_quantized_model_id()
|
|
||||||
if not compiled_model_id:
|
|
||||||
CONSOLE.print(
|
|
||||||
"[red]No quantized ONNX model found. Resume from --from-step quantize or run "
|
|
||||||
"'qc-cli ai-hub quantize' first.[/red]"
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
state_ops.store(config).update(last_compiled_model_id=compiled_model_id)
|
|
||||||
CONSOLE.print("[green]✓[/green] Target runtime is ONNX; skipping final compile.")
|
|
||||||
else:
|
|
||||||
compiled_model_id = _compile_step(
|
|
||||||
cfg,
|
|
||||||
config,
|
|
||||||
model_id=quantized_model_id,
|
|
||||||
)
|
|
||||||
if UploadStep.validate in selected:
|
if UploadStep.validate in selected:
|
||||||
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
|
_validate_step(cfg, config, input_file, compiled_model_id, input_name)
|
||||||
if UploadStep.profile in selected:
|
if UploadStep.profile in selected:
|
||||||
|
|||||||
@@ -150,6 +150,35 @@ def status(config: str = CONFIG_OPT) -> None:
|
|||||||
CONSOLE.print(table)
|
CONSOLE.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(name="mlflow-url")
|
||||||
|
def mlflow_url(config: str = CONFIG_OPT) -> None:
|
||||||
|
"""Print a presigned URL for the configured MLflow tracking server."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
||||||
|
if not tracking_server_name:
|
||||||
|
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = mlflow.create_presigned_tracking_server_url(
|
||||||
|
cfg.aws.region,
|
||||||
|
cfg.aws.profile,
|
||||||
|
tracking_server_name,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
|
||||||
|
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||||
|
CONSOLE.print(f"Reason: {e}")
|
||||||
|
CONSOLE.print(
|
||||||
|
"This command can create presigned URLs only for MLflow tracking servers managed by "
|
||||||
|
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||||
|
CONSOLE.print(f"MLflow UI: {url}")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def destroy(
|
def destroy(
|
||||||
config: str = CONFIG_OPT,
|
config: str = CONFIG_OPT,
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
import secrets
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from src.commands.utils import CONSOLE
|
|
||||||
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def init(
|
|
||||||
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
|
||||||
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
|
||||||
) -> None:
|
|
||||||
"""Write a starter config.yaml to the current directory."""
|
|
||||||
dest = Path(output)
|
|
||||||
if dest.exists() and not force:
|
|
||||||
CONSOLE.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
config = _new_isolated_config()
|
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
config_data = config.model_dump(mode="json")
|
|
||||||
config_data["sagemaker"].pop("role_name", None)
|
|
||||||
with open(dest, "w") as f:
|
|
||||||
yaml.safe_dump(config_data, f, sort_keys=False)
|
|
||||||
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
|
||||||
CONSOLE.print("Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands.")
|
|
||||||
|
|
||||||
|
|
||||||
def _new_isolated_config() -> Config:
|
|
||||||
suffix = secrets.token_hex(6)
|
|
||||||
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
|
|
||||||
config = Config(infra=InfraConfig(stack_name=namespace))
|
|
||||||
config.s3 = S3Config(bucket=f"{namespace}-data")
|
|
||||||
return config
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
import webbrowser
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from src import state as state_ops
|
|
||||||
from src.aws import mlflow as aws_mlflow
|
|
||||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
|
||||||
from src.config import MlflowMode
|
|
||||||
from src.tracking.upload import upload_training_metrics
|
|
||||||
|
|
||||||
app = typer.Typer(help="Manage MLflow tracking server access")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command(name="open")
|
|
||||||
def open_mlflow(config: str = CONFIG_OPT) -> None:
|
|
||||||
"""Open a presigned URL for the configured MLflow tracking server."""
|
|
||||||
cfg = load_cfg(config)
|
|
||||||
tracking_server_name = cfg.effective_mlflow_tracking_server_name
|
|
||||||
if not tracking_server_name:
|
|
||||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
url = aws_mlflow.create_presigned_tracking_server_url(
|
|
||||||
cfg.aws.region,
|
|
||||||
cfg.aws.profile,
|
|
||||||
tracking_server_name,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
|
|
||||||
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
|
|
||||||
CONSOLE.print(f"Reason: {e}")
|
|
||||||
CONSOLE.print(
|
|
||||||
"This command can create presigned URLs only for MLflow tracking servers managed by "
|
|
||||||
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
|
|
||||||
CONSOLE.print(f"MLflow UI: {url}")
|
|
||||||
if webbrowser.open(url):
|
|
||||||
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
|
|
||||||
else:
|
|
||||||
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command(name="upload-metrics")
|
|
||||||
def upload_metrics(
|
|
||||||
job_name: str | None = typer.Argument(None, help="Training job name (default: last submitted job)"),
|
|
||||||
force: bool = typer.Option(False, "--force", help="Upload again even if this job is marked as uploaded"),
|
|
||||||
config: str = CONFIG_OPT,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a completed training job's metric history to MLflow."""
|
|
||||||
cfg = load_cfg(config)
|
|
||||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
|
||||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
st = state_ops.store(config)
|
|
||||||
if not job_name:
|
|
||||||
job_name = st.get_last_training_job()
|
|
||||||
if not job_name:
|
|
||||||
CONSOLE.print(
|
|
||||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
if st.get_training_job(job_name).get("mlflow_metrics_uploaded") and not force:
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Metrics already uploaded for [cyan]{job_name}[/cyan].")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = upload_training_metrics(
|
|
||||||
job_name=job_name,
|
|
||||||
config_path=config,
|
|
||||||
cfg=cfg,
|
|
||||||
force=force,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
if result.metrics_history_uploaded:
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Uploaded training metrics for [cyan]{job_name}[/cyan].")
|
|
||||||
else:
|
|
||||||
CONSOLE.print(
|
|
||||||
f"[yellow]No training_metrics.json was found in the SageMaker model artifact for "
|
|
||||||
f"[cyan]{job_name}[/cyan]. Uploaded SageMaker final metrics only.[/yellow]"
|
|
||||||
)
|
|
||||||
CONSOLE.print(f"MLflow run: [cyan]{result.run_id}[/cyan]")
|
|
||||||
if result.registered_model_version:
|
|
||||||
CONSOLE.print(
|
|
||||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
|
||||||
"([cyan]experiment-latest[/cyan])"
|
|
||||||
)
|
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -12,7 +11,6 @@ from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
|||||||
from src.config import Config, MlflowMode
|
from src.config import Config, MlflowMode
|
||||||
from src.infra.state import read_infra_state
|
from src.infra.state import read_infra_state
|
||||||
from src.tracking.mlflow import MlflowTracker
|
from src.tracking.mlflow import MlflowTracker
|
||||||
from src.tracking.upload import upload_training_metrics
|
|
||||||
|
|
||||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||||
|
|
||||||
@@ -23,8 +21,6 @@ _STATUS_COLOR = {
|
|||||||
"Stopping": "yellow",
|
"Stopping": "yellow",
|
||||||
"Stopped": "dim",
|
"Stopped": "dim",
|
||||||
}
|
}
|
||||||
_TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}
|
|
||||||
DEFAULT_POLL_INTERVAL_SECONDS = 30
|
|
||||||
|
|
||||||
|
|
||||||
def _tracker(cfg):
|
def _tracker(cfg):
|
||||||
@@ -52,100 +48,11 @@ def _sagemaker_role_arn(config_path: str, cfg: Config) -> str:
|
|||||||
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
raise RuntimeError("SageMaker role not found in infra state. Run 'qc-cli infra setup' first.")
|
||||||
|
|
||||||
|
|
||||||
def _print_training_status(status: sm_ops.TrainingJobStatus) -> None:
|
|
||||||
color = _STATUS_COLOR.get(status.status, "white")
|
|
||||||
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
|
||||||
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
|
||||||
if status.created:
|
|
||||||
CONSOLE.print(f"Created: {status.created}")
|
|
||||||
if status.model_artifacts:
|
|
||||||
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
|
||||||
if status.failure_reason:
|
|
||||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_upload_metrics(
|
|
||||||
*,
|
|
||||||
job_name: str,
|
|
||||||
poll_interval: int,
|
|
||||||
config_path: str,
|
|
||||||
cfg: Config,
|
|
||||||
) -> None:
|
|
||||||
st = state_ops.store(config_path)
|
|
||||||
previous_status: str | None = None
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
training_status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
|
||||||
if training_status.status != previous_status:
|
|
||||||
color = _STATUS_COLOR.get(training_status.status, "white")
|
|
||||||
CONSOLE.print(
|
|
||||||
f"Job [cyan]{training_status.name}[/cyan]: "
|
|
||||||
f"[{color}]{training_status.status}[/{color}]"
|
|
||||||
)
|
|
||||||
previous_status = training_status.status
|
|
||||||
if training_status.status in _TERMINAL_STATUSES:
|
|
||||||
_print_training_status(training_status)
|
|
||||||
if training_status.status != "Completed":
|
|
||||||
raise typer.Exit(1)
|
|
||||||
try:
|
|
||||||
result = upload_training_metrics(
|
|
||||||
job_name=job_name,
|
|
||||||
config_path=config_path,
|
|
||||||
cfg=cfg,
|
|
||||||
)
|
|
||||||
if result.metrics_history_uploaded:
|
|
||||||
CONSOLE.print(
|
|
||||||
f"[green]✓[/green] Uploaded training metrics to MLflow run "
|
|
||||||
f"[cyan]{result.run_id}[/cyan]."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
CONSOLE.print(
|
|
||||||
"[yellow]No training_metrics.json was found in the SageMaker model artifact. "
|
|
||||||
"Uploaded SageMaker final metrics only.[/yellow]"
|
|
||||||
)
|
|
||||||
if result.registered_model_version:
|
|
||||||
CONSOLE.print(
|
|
||||||
f"MLflow model version: [cyan]{result.registered_model_version}[/cyan] "
|
|
||||||
"([cyan]experiment-latest[/cyan])"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]MLflow metric upload failed: {e}[/red]")
|
|
||||||
CONSOLE.print(
|
|
||||||
f"Retry with [cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
job_state = st.get_training_job(job_name)
|
|
||||||
if job_state.get("mlflow_run_id") and cfg.mlflow.mode is not MlflowMode.disabled:
|
|
||||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
|
||||||
return
|
|
||||||
time.sleep(poll_interval)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
CONSOLE.print("[yellow]Stopped waiting. The SageMaker training job is still running.[/yellow]")
|
|
||||||
raise typer.Exit(130)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def start(
|
def start(config: str = CONFIG_OPT) -> None:
|
||||||
upload_metrics: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--upload-metrics",
|
|
||||||
help="Wait for completion, then upload training metrics to MLflow",
|
|
||||||
),
|
|
||||||
poll_interval: int = typer.Option(
|
|
||||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
|
||||||
"--poll-interval",
|
|
||||||
min=1,
|
|
||||||
help="Seconds between status checks when --upload-metrics is used",
|
|
||||||
),
|
|
||||||
config: str = CONFIG_OPT,
|
|
||||||
) -> None:
|
|
||||||
"""Submit a SageMaker training job."""
|
"""Submit a SageMaker training job."""
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
if upload_metrics and cfg.mlflow.mode is MlflowMode.disabled:
|
|
||||||
CONSOLE.print("[red]--upload-metrics requires MLflow to be enabled in config.yaml.[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
if not cfg.sagemaker.training.image_uri:
|
if not cfg.sagemaker.training.image_uri:
|
||||||
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
CONSOLE.print("[red]sagemaker.training.image_uri is required in config.yaml.[/red]")
|
||||||
CONSOLE.print(
|
CONSOLE.print(
|
||||||
@@ -182,36 +89,20 @@ def start(
|
|||||||
|
|
||||||
st = state_ops.store(config)
|
st = state_ops.store(config)
|
||||||
st.set_last_training_job(job_name)
|
st.set_last_training_job(job_name)
|
||||||
try:
|
run_id = tracker.start_training_run(
|
||||||
run_id = tracker.start_training_run(
|
training_job,
|
||||||
training_job,
|
region=cfg.aws.region,
|
||||||
region=cfg.aws.region,
|
profile=cfg.aws.profile,
|
||||||
profile=cfg.aws.profile,
|
role_arn=role_arn,
|
||||||
role_arn=role_arn,
|
)
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
run_id = None
|
|
||||||
CONSOLE.print(f"[yellow]MLflow run creation failed: {e}[/yellow]")
|
|
||||||
CONSOLE.print(
|
|
||||||
"The SageMaker job is still running. Upload metrics after completion with "
|
|
||||||
f"[cyan]qc-cli mlflow upload-metrics {job_name}[/cyan]."
|
|
||||||
)
|
|
||||||
if run_id:
|
if run_id:
|
||||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||||
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||||
if run_id:
|
if run_id:
|
||||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||||
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]")
|
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||||
if upload_metrics:
|
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||||
_wait_and_upload_metrics(
|
|
||||||
job_name=job_name,
|
|
||||||
poll_interval=poll_interval,
|
|
||||||
config_path=config,
|
|
||||||
cfg=cfg,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@@ -232,7 +123,35 @@ def status(
|
|||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
||||||
_print_training_status(status)
|
color = _STATUS_COLOR.get(status.status, "white")
|
||||||
|
|
||||||
|
CONSOLE.print(f"Job: [cyan]{status.name}[/cyan]")
|
||||||
|
CONSOLE.print(f"Status: [{color}]{status.status}[/{color}]")
|
||||||
|
if status.created:
|
||||||
|
CONSOLE.print(f"Created: {status.created}")
|
||||||
|
if status.model_artifacts:
|
||||||
|
CONSOLE.print(f"Artifacts: {status.model_artifacts}")
|
||||||
|
if status.failure_reason:
|
||||||
|
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||||
|
|
||||||
|
job_state = st.get_training_job(job_name)
|
||||||
|
run_id = job_state.get("mlflow_run_id")
|
||||||
|
already_registered = job_state.get("registered_model_version")
|
||||||
|
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||||
|
tracker = _tracker(cfg)
|
||||||
|
version = tracker.finalize_training_run(
|
||||||
|
run_id=str(run_id),
|
||||||
|
training_job_status=status,
|
||||||
|
)
|
||||||
|
updates = {"mlflow_finalized_status": status.status}
|
||||||
|
if version:
|
||||||
|
updates["registered_model_version"] = version
|
||||||
|
st.update_training_job(job_name, **updates)
|
||||||
|
if version:
|
||||||
|
st.set_latest_experiment_model_version(version)
|
||||||
|
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
|
||||||
|
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||||
|
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||||
|
|
||||||
|
|
||||||
@app.command(name="list")
|
@app.command(name="list")
|
||||||
|
|||||||
@@ -1,70 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
|
||||||
|
|
||||||
from src.aws import s3 as s3_ops
|
|
||||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def upload(
|
|
||||||
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
|
||||||
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
|
|
||||||
config: str = CONFIG_OPT,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a local file or directory to S3."""
|
|
||||||
cfg = load_cfg(config)
|
|
||||||
|
|
||||||
if path.is_file():
|
|
||||||
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
|
|
||||||
try:
|
|
||||||
with CONSOLE.status(f"Uploading {path.name}..."):
|
|
||||||
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
CONSOLE.print(f"[green]✓[/green] {path.name} -> {uri}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if path.is_dir():
|
|
||||||
if s3_key is not None:
|
|
||||||
CONSOLE.print("[red]--s3-key can only be used when uploading a single file.[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
files = [file for file in path.rglob("*") if file.is_file()]
|
|
||||||
if not files:
|
|
||||||
CONSOLE.print("[yellow]No files found in directory.[/yellow]")
|
|
||||||
raise typer.Exit(0)
|
|
||||||
|
|
||||||
prefix = cfg.s3.data_prefix
|
|
||||||
CONSOLE.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
|
||||||
try:
|
|
||||||
with Progress(
|
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
|
||||||
BarColumn(),
|
|
||||||
TaskProgressColumn(),
|
|
||||||
console=CONSOLE,
|
|
||||||
) as progress:
|
|
||||||
task = progress.add_task("Uploading...", total=len(files))
|
|
||||||
count = s3_ops.upload_dir(
|
|
||||||
cfg.aws.region,
|
|
||||||
cfg.aws.profile,
|
|
||||||
cfg.s3.bucket,
|
|
||||||
str(path),
|
|
||||||
prefix,
|
|
||||||
on_progress=lambda: progress.advance(task),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
CONSOLE.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
|
||||||
return
|
|
||||||
|
|
||||||
CONSOLE.print(f"[red]Path not found: {path}[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
@@ -4,8 +4,7 @@ from typing import Any, Literal, TypedDict
|
|||||||
|
|
||||||
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
from mypy_boto3_s3.literals import BucketLocationConstraintType
|
||||||
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from qai_hub.client import Device
|
|
||||||
|
|
||||||
|
|
||||||
class MlflowMode(StrEnum):
|
class MlflowMode(StrEnum):
|
||||||
@@ -82,7 +81,7 @@ class SageMakerConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AIHubConfig(BaseModel):
|
class AIHubConfig(BaseModel):
|
||||||
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)"))
|
device: str = "Samsung Galaxy S25 (Family)"
|
||||||
target_runtime: str = "tflite"
|
target_runtime: str = "tflite"
|
||||||
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
|
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
|
||||||
job_name: str | None = None
|
job_name: str | None = None
|
||||||
@@ -92,13 +91,6 @@ class AIHubConfig(BaseModel):
|
|||||||
quantize_options: str | None = None
|
quantize_options: str | None = None
|
||||||
output_dir: str = "build/qai-hub"
|
output_dir: str = "build/qai-hub"
|
||||||
|
|
||||||
@field_validator("device", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def parse_device(cls, value: Any) -> Any:
|
|
||||||
if isinstance(value, str):
|
|
||||||
return Device(value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class MlflowConfig(BaseModel):
|
class MlflowConfig(BaseModel):
|
||||||
mode: MlflowMode = MlflowMode.disabled
|
mode: MlflowMode = MlflowMode.disabled
|
||||||
|
|||||||
111
src/main.py
111
src/main.py
@@ -1,14 +1,115 @@
|
|||||||
import typer
|
import secrets
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from src.commands import ai_hub, infra, init, mlflow, train, upload
|
import typer
|
||||||
|
import yaml
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||||
|
|
||||||
|
from src.aws import s3 as s3_ops
|
||||||
|
from src.commands import ai_hub, infra, train
|
||||||
|
from src.commands.utils import CONFIG_OPT, load_cfg
|
||||||
|
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
|
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
app.add_typer(init.app)
|
|
||||||
app.add_typer(upload.app)
|
|
||||||
app.add_typer(mlflow.app, name="mlflow")
|
|
||||||
app.add_typer(infra.app, name="infra")
|
app.add_typer(infra.app, name="infra")
|
||||||
app.add_typer(train.app, name="train")
|
app.add_typer(train.app, name="train")
|
||||||
app.add_typer(ai_hub.app, name="ai-hub")
|
app.add_typer(ai_hub.app, name="ai-hub")
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def init(
|
||||||
|
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
|
||||||
|
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
|
||||||
|
) -> None:
|
||||||
|
"""Write a starter config.yaml to the current directory."""
|
||||||
|
dest = Path(output)
|
||||||
|
if dest.exists() and not force:
|
||||||
|
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
config = _new_isolated_config()
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_data = config.model_dump(mode="json")
|
||||||
|
config_data["sagemaker"].pop("role_name", None)
|
||||||
|
with open(dest, "w") as f:
|
||||||
|
yaml.safe_dump(config_data, f, sort_keys=False)
|
||||||
|
|
||||||
|
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
|
||||||
|
console.print(
|
||||||
|
"Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _new_isolated_config() -> Config:
|
||||||
|
suffix = secrets.token_hex(6)
|
||||||
|
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
|
||||||
|
config = Config(infra=InfraConfig(stack_name=namespace))
|
||||||
|
config.s3 = S3Config(bucket=f"{namespace}-data")
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def upload(
|
||||||
|
path: Path = typer.Argument(..., help="Local file or directory to upload"),
|
||||||
|
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
|
||||||
|
config: str = CONFIG_OPT,
|
||||||
|
) -> None:
|
||||||
|
"""Upload a local file or directory to S3."""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
|
||||||
|
try:
|
||||||
|
with console.status(f"Uploading {path.name}..."):
|
||||||
|
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Upload failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[green]✓[/green] {path.name} -> {uri}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if path.is_dir():
|
||||||
|
if s3_key is not None:
|
||||||
|
console.print("[red]--s3-key can only be used when uploading a single file.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
files = [file for file in path.rglob("*") if file.is_file()]
|
||||||
|
if not files:
|
||||||
|
console.print("[yellow]No files found in directory.[/yellow]")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
prefix = cfg.s3.data_prefix
|
||||||
|
console.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||||
|
try:
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
console=console,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("Uploading...", total=len(files))
|
||||||
|
count = s3_ops.upload_dir(
|
||||||
|
cfg.aws.region,
|
||||||
|
cfg.aws.profile,
|
||||||
|
cfg.s3.bucket,
|
||||||
|
str(path),
|
||||||
|
prefix,
|
||||||
|
on_progress=lambda: progress.advance(task),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Upload failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[red]Path not found: {path}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
|
|||||||
@@ -28,20 +28,31 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
|
|||||||
|
|
||||||
|
|
||||||
def submit_compile_job(
|
def submit_compile_job(
|
||||||
model: Model,
|
model: Any,
|
||||||
device: Device,
|
device_name: str,
|
||||||
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
input_specs: dict[str, tuple[tuple[int, ...], str]],
|
||||||
target_runtime: str,
|
target_runtime: str,
|
||||||
options: str | None = None,
|
options: str | None = None,
|
||||||
job_name: str | None = None,
|
job_name: str | None = None,
|
||||||
|
model_name: str | None = None,
|
||||||
) -> ModelJobResult:
|
) -> ModelJobResult:
|
||||||
compile_options = f"--target_runtime {target_runtime}"
|
compile_options = f"--target_runtime {target_runtime}"
|
||||||
if options:
|
if options:
|
||||||
compile_options = f"{compile_options} {options}"
|
compile_options = f"{compile_options} {options}"
|
||||||
|
|
||||||
|
model_arg = model
|
||||||
|
if isinstance(model, Path):
|
||||||
|
model_arg = str(model)
|
||||||
|
elif isinstance(model, str):
|
||||||
|
candidate = Path(model)
|
||||||
|
model_arg = model if candidate.exists() or candidate.suffix else hub.get_model(model)
|
||||||
|
|
||||||
|
if model_name and isinstance(model_arg, str) and Path(model_arg).exists():
|
||||||
|
model_arg = hub.upload_model(model_arg, name=model_name)
|
||||||
|
|
||||||
job = hub.submit_compile_job(
|
job = hub.submit_compile_job(
|
||||||
model=model,
|
model=model_arg,
|
||||||
device=device,
|
device=Device(device_name),
|
||||||
name=job_name,
|
name=job_name,
|
||||||
input_specs=input_specs,
|
input_specs=input_specs,
|
||||||
options=compile_options,
|
options=compile_options,
|
||||||
@@ -53,15 +64,15 @@ def submit_compile_job(
|
|||||||
|
|
||||||
|
|
||||||
def submit_inference_job(
|
def submit_inference_job(
|
||||||
model: Model,
|
model_id: str,
|
||||||
device: Device,
|
device_name: str,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
output_dir: str | Path,
|
output_dir: str | Path,
|
||||||
job_name: str | None = None,
|
job_name: str | None = None,
|
||||||
) -> InferenceJobResult:
|
) -> InferenceJobResult:
|
||||||
job = hub.submit_inference_job(
|
job = hub.submit_inference_job(
|
||||||
model=model,
|
model=hub.get_model(model_id),
|
||||||
device=device,
|
device=Device(device_name),
|
||||||
inputs=_dataset_entries(inputs),
|
inputs=_dataset_entries(inputs),
|
||||||
name=job_name,
|
name=job_name,
|
||||||
)
|
)
|
||||||
@@ -72,14 +83,14 @@ def submit_inference_job(
|
|||||||
|
|
||||||
|
|
||||||
def submit_profile_job(
|
def submit_profile_job(
|
||||||
model: Model,
|
model_id: str,
|
||||||
device: Device,
|
device_name: str,
|
||||||
options: str | None = None,
|
options: str | None = None,
|
||||||
job_name: str | None = None,
|
job_name: str | None = None,
|
||||||
) -> ProfileJobResult:
|
) -> ProfileJobResult:
|
||||||
job = hub.submit_profile_job(
|
job = hub.submit_profile_job(
|
||||||
model=model,
|
model=hub.get_model(model_id),
|
||||||
device=device,
|
device=Device(device_name),
|
||||||
name=job_name,
|
name=job_name,
|
||||||
options=options or "",
|
options=options or "",
|
||||||
)
|
)
|
||||||
@@ -87,13 +98,17 @@ def submit_profile_job(
|
|||||||
|
|
||||||
|
|
||||||
def submit_quantize_job(
|
def submit_quantize_job(
|
||||||
model: Model,
|
model: str | Path,
|
||||||
calibration_data: dict[str, Any],
|
calibration_data: dict[str, Any],
|
||||||
options: str | None = None,
|
options: str | None = None,
|
||||||
job_name: str | None = None,
|
job_name: str | None = None,
|
||||||
|
model_name: str | None = None,
|
||||||
) -> ModelJobResult:
|
) -> ModelJobResult:
|
||||||
|
model_arg = str(model)
|
||||||
|
if model_name and Path(model_arg).exists():
|
||||||
|
model_arg = hub.upload_model(model_arg, name=model_name)
|
||||||
job = hub.submit_quantize_job(
|
job = hub.submit_quantize_job(
|
||||||
model=model,
|
model=model_arg,
|
||||||
calibration_data=_dataset_entries(calibration_data),
|
calibration_data=_dataset_entries(calibration_data),
|
||||||
weights_dtype=QuantizeDtype.INT8,
|
weights_dtype=QuantizeDtype.INT8,
|
||||||
activations_dtype=QuantizeDtype.INT8,
|
activations_dtype=QuantizeDtype.INT8,
|
||||||
|
|||||||
@@ -37,10 +37,6 @@ class CliStateStore:
|
|||||||
value = self.get("last_model_artifact")
|
value = self.get("last_model_artifact")
|
||||||
return str(value) if value else None
|
return str(value) if value else None
|
||||||
|
|
||||||
def get_last_optimized_model_id(self) -> str | None:
|
|
||||||
value = self.get("last_optimized_model_id")
|
|
||||||
return str(value) if value else None
|
|
||||||
|
|
||||||
def get_last_quantized_model_id(self) -> str | None:
|
def get_last_quantized_model_id(self) -> str | None:
|
||||||
value = self.get("last_quantized_model_id")
|
value = self.get("last_quantized_model_id")
|
||||||
return str(value) if value else None
|
return str(value) if value else None
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from src.tracking.mlflow import FinalizeResult, MlflowTracker, NoopTracker, Tracker
|
from src.tracking.mlflow import MlflowTracker, NoopTracker, Tracker
|
||||||
|
|
||||||
__all__ = ["FinalizeResult", "MlflowTracker", "NoopTracker", "Tracker"]
|
__all__ = ["MlflowTracker", "NoopTracker", "Tracker"]
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
import json
|
|
||||||
import math
|
|
||||||
import tarfile
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import PurePosixPath
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
METRICS_ARTIFACT_NAME = "training_metrics.json"
|
|
||||||
METRICS_SCHEMA_VERSION = 1
|
|
||||||
MAX_METRICS_ARTIFACT_BYTES = 10 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MetricStep:
|
|
||||||
step: int
|
|
||||||
metrics: dict[str, float]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class TrainingMetrics:
|
|
||||||
steps: list[MetricStep]
|
|
||||||
summary: dict[str, float]
|
|
||||||
raw: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_training_metrics(data: bytes) -> TrainingMetrics:
|
|
||||||
try:
|
|
||||||
value = json.loads(data)
|
|
||||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
|
||||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} is not valid JSON") from exc
|
|
||||||
if not isinstance(value, dict):
|
|
||||||
raise ValueError(f"{METRICS_ARTIFACT_NAME} must contain a JSON object")
|
|
||||||
if value.get("schema_version") != METRICS_SCHEMA_VERSION:
|
|
||||||
raise ValueError(f"Unsupported training metrics schema version: {value.get('schema_version')!r}")
|
|
||||||
|
|
||||||
raw_steps = value.get("steps")
|
|
||||||
if not isinstance(raw_steps, list):
|
|
||||||
raise ValueError("training metrics 'steps' must be a list")
|
|
||||||
|
|
||||||
steps: list[MetricStep] = []
|
|
||||||
previous_step: int | None = None
|
|
||||||
for index, raw_step in enumerate(raw_steps):
|
|
||||||
if not isinstance(raw_step, dict):
|
|
||||||
raise ValueError(f"training metrics step {index} must be an object")
|
|
||||||
step = raw_step.get("step")
|
|
||||||
if isinstance(step, bool) or not isinstance(step, int) or step < 0:
|
|
||||||
raise ValueError(f"training metrics step {index} has an invalid 'step'")
|
|
||||||
if previous_step is not None and step <= previous_step:
|
|
||||||
raise ValueError("training metrics steps must be unique and strictly increasing")
|
|
||||||
metrics = _numeric_metrics(raw_step.get("metrics"), f"training metrics step {step}")
|
|
||||||
steps.append(MetricStep(step=step, metrics=metrics))
|
|
||||||
previous_step = step
|
|
||||||
|
|
||||||
summary = _numeric_metrics(value.get("summary", {}), "training metrics summary")
|
|
||||||
return TrainingMetrics(steps=steps, summary=summary, raw=value)
|
|
||||||
|
|
||||||
|
|
||||||
def read_training_metrics_from_tar(archive_path: str) -> bytes | None:
|
|
||||||
with tarfile.open(archive_path, mode="r:*") as archive:
|
|
||||||
matches = [
|
|
||||||
member
|
|
||||||
for member in archive.getmembers()
|
|
||||||
if member.isfile() and PurePosixPath(member.name).name == METRICS_ARTIFACT_NAME
|
|
||||||
]
|
|
||||||
if not matches:
|
|
||||||
return None
|
|
||||||
if len(matches) > 1:
|
|
||||||
raise ValueError(f"Model archive contains multiple {METRICS_ARTIFACT_NAME} files")
|
|
||||||
if matches[0].size > MAX_METRICS_ARTIFACT_BYTES:
|
|
||||||
raise ValueError(
|
|
||||||
f"{METRICS_ARTIFACT_NAME} exceeds the {MAX_METRICS_ARTIFACT_BYTES}-byte size limit"
|
|
||||||
)
|
|
||||||
extracted = archive.extractfile(matches[0])
|
|
||||||
if extracted is None:
|
|
||||||
raise ValueError(f"Could not read {METRICS_ARTIFACT_NAME} from model archive")
|
|
||||||
return extracted.read()
|
|
||||||
|
|
||||||
|
|
||||||
def _numeric_metrics(value: Any, context: str) -> dict[str, float]:
|
|
||||||
if not isinstance(value, dict):
|
|
||||||
raise ValueError(f"{context} 'metrics' must be an object")
|
|
||||||
|
|
||||||
metrics: dict[str, float] = {}
|
|
||||||
for raw_name, raw_value in value.items():
|
|
||||||
if not isinstance(raw_name, str) or not raw_name:
|
|
||||||
raise ValueError(f"{context} contains an invalid metric name")
|
|
||||||
if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
|
|
||||||
raise ValueError(f"{context} metric '{raw_name}' must be numeric")
|
|
||||||
metric_value = float(raw_value)
|
|
||||||
if not math.isfinite(metric_value):
|
|
||||||
raise ValueError(f"{context} metric '{raw_name}' must be finite")
|
|
||||||
metrics[raw_name] = metric_value
|
|
||||||
return metrics
|
|
||||||
@@ -1,46 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import mlflow
|
import mlflow
|
||||||
from mlflow.tracking import MlflowClient
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
from src.aws import s3
|
from src.aws import mlflow as aws_mlflow
|
||||||
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
|
|
||||||
from src.config import Config, MlflowMode
|
from src.config import Config, MlflowMode
|
||||||
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FinalizeResult:
|
|
||||||
registered_model_version: str | None = None
|
|
||||||
warnings: tuple[str, ...] = ()
|
|
||||||
|
|
||||||
|
|
||||||
class Tracker(Protocol):
|
class Tracker(Protocol):
|
||||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: ...
|
||||||
|
|
||||||
def finalize_training_run(
|
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: ...
|
||||||
self,
|
|
||||||
*,
|
|
||||||
run_id: str | None,
|
|
||||||
training_job_status: Any,
|
|
||||||
region: str,
|
|
||||||
profile: str,
|
|
||||||
command: str,
|
|
||||||
) -> FinalizeResult: ...
|
|
||||||
|
|
||||||
def ensure_training_run(self, job_name: str) -> str: ...
|
|
||||||
|
|
||||||
def upload_training_metrics(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
run_id: str,
|
|
||||||
training_job_status: Any,
|
|
||||||
region: str,
|
|
||||||
profile: str,
|
|
||||||
) -> bool: ...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -48,29 +20,8 @@ class NoopTracker:
|
|||||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def finalize_training_run(
|
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||||
self,
|
return None
|
||||||
*,
|
|
||||||
run_id: str | None,
|
|
||||||
training_job_status: Any,
|
|
||||||
region: str,
|
|
||||||
profile: str,
|
|
||||||
command: str,
|
|
||||||
) -> FinalizeResult:
|
|
||||||
return FinalizeResult()
|
|
||||||
|
|
||||||
def ensure_training_run(self, job_name: str) -> str:
|
|
||||||
raise RuntimeError("MLflow is disabled.")
|
|
||||||
|
|
||||||
def upload_training_metrics(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
run_id: str,
|
|
||||||
training_job_status: Any,
|
|
||||||
region: str,
|
|
||||||
profile: str,
|
|
||||||
) -> bool:
|
|
||||||
raise RuntimeError("MLflow is disabled.")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -79,7 +30,6 @@ class MlflowTracker:
|
|||||||
experiment_name: str
|
experiment_name: str
|
||||||
registered_model_name: str
|
registered_model_name: str
|
||||||
register_trained_models: bool
|
register_trained_models: bool
|
||||||
tracking_backend: MlflowTrackingBackend
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, cfg: Config) -> Tracker:
|
def from_config(cls, cfg: Config) -> Tracker:
|
||||||
@@ -92,138 +42,94 @@ class MlflowTracker:
|
|||||||
if not tracking_server_name:
|
if not tracking_server_name:
|
||||||
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
raise RuntimeError("MLflow tracking server name could not be resolved.")
|
||||||
|
|
||||||
tracking_backend = mlflow_tracking_backend_from_config(cfg)
|
tracking_uri = aws_mlflow.get_tracking_server_arn(
|
||||||
|
cfg.aws.region,
|
||||||
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
|
cfg.aws.profile,
|
||||||
with tracking_backend.auth_env():
|
tracking_server_name,
|
||||||
mlflow.set_tracking_uri(tracking_uri)
|
)
|
||||||
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
mlflow.set_tracking_uri(tracking_uri)
|
||||||
|
mlflow.set_experiment(cfg.mlflow.experiment_name)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
tracking_uri=tracking_uri,
|
tracking_uri=tracking_uri,
|
||||||
experiment_name=cfg.mlflow.experiment_name,
|
experiment_name=cfg.mlflow.experiment_name,
|
||||||
registered_model_name=cfg.mlflow.registered_model_name,
|
registered_model_name=cfg.mlflow.registered_model_name,
|
||||||
register_trained_models=cfg.mlflow.register_trained_models,
|
register_trained_models=cfg.mlflow.register_trained_models,
|
||||||
tracking_backend=tracking_backend,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
|
||||||
with self.tracking_backend.auth_env():
|
run = mlflow.start_run(run_name=training_job.job_name)
|
||||||
with mlflow.start_run(run_name=training_job.job_name) as run:
|
run_id = str(run.info.run_id)
|
||||||
run_id = str(run.info.run_id)
|
|
||||||
self._log_params(
|
|
||||||
self.tracking_backend.training_run_params(
|
|
||||||
training_job,
|
|
||||||
region=region,
|
|
||||||
profile=profile,
|
|
||||||
role_arn=role_arn,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
|
||||||
mlflow.set_tags(
|
|
||||||
{
|
|
||||||
"qc_cli.stage": "experiment",
|
|
||||||
"qc_cli.artifact_kind": "trained_source",
|
|
||||||
"qc_cli.source": self.tracking_backend.provider_name,
|
|
||||||
"qc_cli.command": "train start",
|
|
||||||
**self.tracking_backend.training_run_tags(training_job),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return run_id
|
|
||||||
|
|
||||||
def finalize_training_run(
|
params = {
|
||||||
self,
|
"aws.region": region,
|
||||||
*,
|
"aws.profile": profile,
|
||||||
run_id: str | None,
|
"sagemaker.role_arn": role_arn,
|
||||||
training_job_status: Any,
|
"sagemaker.job_name": training_job.job_name,
|
||||||
region: str,
|
"sagemaker.training_image": training_job.image_uri,
|
||||||
profile: str,
|
"sagemaker.instance_type": training_job.instance_type,
|
||||||
command: str,
|
"sagemaker.instance_count": training_job.instance_count,
|
||||||
) -> FinalizeResult:
|
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||||
|
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||||
|
"sagemaker.entry_point": training_job.entry_point,
|
||||||
|
"sagemaker.source_dir": training_job.source_dir,
|
||||||
|
}
|
||||||
|
self._log_params(params)
|
||||||
|
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
|
||||||
|
mlflow.set_tags(
|
||||||
|
{
|
||||||
|
"qc_cli.stage": "experiment",
|
||||||
|
"qc_cli.artifact_kind": "trained_source",
|
||||||
|
"qc_cli.source": "sagemaker",
|
||||||
|
"qc_cli.command": "train start",
|
||||||
|
"sagemaker.job_name": training_job.job_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mlflow.end_run()
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None:
|
||||||
if not run_id:
|
if not run_id:
|
||||||
return FinalizeResult()
|
return None
|
||||||
|
|
||||||
with self.tracking_backend.auth_env():
|
with mlflow.start_run(run_id=run_id):
|
||||||
with mlflow.start_run(run_id=run_id):
|
self._log_params(
|
||||||
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
{
|
||||||
self._log_final_metrics(training_job_status.raw)
|
"sagemaker.training_status": training_job_status.status,
|
||||||
mlflow.set_tag("qc_cli.command", command)
|
"sagemaker.created_at": training_job_status.created,
|
||||||
|
"sagemaker.modified_at": training_job_status.modified,
|
||||||
|
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||||
|
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._log_final_metrics(training_job_status.raw)
|
||||||
|
mlflow.set_tag("qc_cli.command", "train status")
|
||||||
|
|
||||||
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
|
||||||
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
|
||||||
return FinalizeResult()
|
return None
|
||||||
|
|
||||||
if not self.register_trained_models:
|
if not self.register_trained_models:
|
||||||
return FinalizeResult()
|
return None
|
||||||
|
|
||||||
client = MlflowClient()
|
|
||||||
self._ensure_registered_model(client, self.registered_model_name)
|
|
||||||
version = client.create_model_version(
|
|
||||||
name=self.registered_model_name,
|
|
||||||
source=training_job_status.model_artifacts,
|
|
||||||
run_id=run_id,
|
|
||||||
tags={
|
|
||||||
"qc_cli.stage": "experiment",
|
|
||||||
"qc_cli.artifact_kind": "trained_source",
|
|
||||||
"qc_cli.source": self.tracking_backend.provider_name,
|
|
||||||
**self.tracking_backend.model_version_tags(training_job_status),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
version_number = str(version.version)
|
|
||||||
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
|
||||||
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
|
||||||
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
|
||||||
return FinalizeResult(registered_model_version=version_number)
|
|
||||||
|
|
||||||
def ensure_training_run(self, job_name: str) -> str:
|
|
||||||
with self.tracking_backend.auth_env():
|
|
||||||
client = MlflowClient()
|
client = MlflowClient()
|
||||||
experiment = client.get_experiment_by_name(self.experiment_name)
|
self._ensure_registered_model(client, self.registered_model_name)
|
||||||
if experiment is None:
|
version = client.create_model_version(
|
||||||
experiment_id = mlflow.create_experiment(self.experiment_name)
|
name=self.registered_model_name,
|
||||||
else:
|
source=training_job_status.model_artifacts,
|
||||||
experiment_id = experiment.experiment_id
|
run_id=run_id,
|
||||||
|
|
||||||
for run in client.search_runs([experiment_id], max_results=1000):
|
|
||||||
if run.data.tags.get("sagemaker.job_name") == job_name:
|
|
||||||
return str(run.info.run_id)
|
|
||||||
|
|
||||||
run = client.create_run(
|
|
||||||
experiment_id,
|
|
||||||
run_name=job_name,
|
|
||||||
tags={
|
tags={
|
||||||
"qc_cli.stage": "experiment",
|
"qc_cli.stage": "experiment",
|
||||||
"qc_cli.artifact_kind": "trained_source",
|
"qc_cli.artifact_kind": "trained_source",
|
||||||
"qc_cli.source": self.tracking_backend.provider_name,
|
"qc_cli.source": "sagemaker",
|
||||||
"qc_cli.command": "mlflow upload-metrics",
|
"sagemaker.job_name": training_job_status.name,
|
||||||
"sagemaker.job_name": job_name,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return str(run.info.run_id)
|
version_number = str(version.version)
|
||||||
|
client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
|
||||||
def upload_training_metrics(
|
mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
|
||||||
self,
|
mlflow.set_tag("qc_cli.registered_model_version", version_number)
|
||||||
*,
|
return version_number
|
||||||
run_id: str,
|
|
||||||
training_job_status: Any,
|
|
||||||
region: str,
|
|
||||||
profile: str,
|
|
||||||
) -> bool:
|
|
||||||
if not training_job_status.model_artifacts:
|
|
||||||
raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")
|
|
||||||
|
|
||||||
with self.tracking_backend.auth_env():
|
|
||||||
with mlflow.start_run(run_id=run_id):
|
|
||||||
self._log_params(self.tracking_backend.training_status_params(training_job_status))
|
|
||||||
self._log_final_metrics(training_job_status.raw)
|
|
||||||
history_uploaded = self._log_training_metrics(
|
|
||||||
training_job_status.model_artifacts,
|
|
||||||
region=region,
|
|
||||||
profile=profile,
|
|
||||||
)
|
|
||||||
mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
|
|
||||||
mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
|
|
||||||
return history_uploaded
|
|
||||||
|
|
||||||
def _log_params(self, params: dict[str, Any]) -> None:
|
def _log_params(self, params: dict[str, Any]) -> None:
|
||||||
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
cleaned = {key: str(value) for key, value in params.items() if value is not None}
|
||||||
@@ -240,26 +146,6 @@ class MlflowTracker:
|
|||||||
if metrics:
|
if metrics:
|
||||||
mlflow.log_metrics(metrics)
|
mlflow.log_metrics(metrics)
|
||||||
|
|
||||||
def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
|
|
||||||
with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
|
|
||||||
archive_path = s3.download_file(
|
|
||||||
region,
|
|
||||||
profile,
|
|
||||||
model_artifacts,
|
|
||||||
os.path.join(temp_dir, "model.tar.gz"),
|
|
||||||
)
|
|
||||||
metrics_data = read_training_metrics_from_tar(archive_path)
|
|
||||||
if metrics_data is None:
|
|
||||||
return False
|
|
||||||
metrics = parse_training_metrics(metrics_data)
|
|
||||||
for metric_step in metrics.steps:
|
|
||||||
if metric_step.metrics:
|
|
||||||
mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
|
|
||||||
if metrics.summary:
|
|
||||||
mlflow.log_metrics(metrics.summary)
|
|
||||||
mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
|
||||||
try:
|
try:
|
||||||
client.get_registered_model(name)
|
client.get_registered_model(name)
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from src import state as state_ops
|
|
||||||
from src.aws import sagemaker as sm_ops
|
|
||||||
from src.config import Config, MlflowMode
|
|
||||||
from src.tracking.mlflow import MlflowTracker
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class MetricsUploadResult:
|
|
||||||
run_id: str
|
|
||||||
registered_model_version: str | None = None
|
|
||||||
metrics_history_uploaded: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
def upload_training_metrics(
|
|
||||||
*,
|
|
||||||
job_name: str,
|
|
||||||
config_path: str,
|
|
||||||
cfg: Config,
|
|
||||||
force: bool = False,
|
|
||||||
) -> MetricsUploadResult:
|
|
||||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
|
||||||
raise RuntimeError("MLflow is disabled in config.yaml.")
|
|
||||||
|
|
||||||
st = state_ops.store(config_path)
|
|
||||||
job_state = st.get_training_job(job_name)
|
|
||||||
if job_state.get("mlflow_metrics_uploaded") and not force:
|
|
||||||
return MetricsUploadResult(
|
|
||||||
run_id=str(job_state.get("mlflow_run_id") or ""),
|
|
||||||
registered_model_version=(
|
|
||||||
str(job_state["registered_model_version"])
|
|
||||||
if job_state.get("registered_model_version")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
metrics_history_uploaded=bool(job_state.get("mlflow_metrics_history_uploaded", True)),
|
|
||||||
)
|
|
||||||
|
|
||||||
status = sm_ops.get_training_job_status(cfg.aws.boto3_session, job_name)
|
|
||||||
if status.status != "Completed":
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Training job {job_name} is {status.status}; metrics can be uploaded only after completion."
|
|
||||||
)
|
|
||||||
|
|
||||||
tracker = MlflowTracker.from_config(cfg)
|
|
||||||
run_id = str(job_state.get("mlflow_run_id") or tracker.ensure_training_run(job_name))
|
|
||||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
|
||||||
metrics_history_uploaded = tracker.upload_training_metrics(
|
|
||||||
run_id=run_id,
|
|
||||||
training_job_status=status,
|
|
||||||
region=cfg.aws.region,
|
|
||||||
profile=cfg.aws.profile,
|
|
||||||
)
|
|
||||||
finalized = tracker.finalize_training_run(
|
|
||||||
run_id=run_id,
|
|
||||||
training_job_status=status,
|
|
||||||
region=cfg.aws.region,
|
|
||||||
profile=cfg.aws.profile,
|
|
||||||
command="mlflow upload-metrics",
|
|
||||||
)
|
|
||||||
updates = {
|
|
||||||
"mlflow_metrics_uploaded": True,
|
|
||||||
"mlflow_metrics_history_uploaded": metrics_history_uploaded,
|
|
||||||
"mlflow_finalized_status": status.status,
|
|
||||||
}
|
|
||||||
if finalized.registered_model_version:
|
|
||||||
updates["registered_model_version"] = finalized.registered_model_version
|
|
||||||
st.update_training_job(job_name, **updates)
|
|
||||||
if finalized.registered_model_version:
|
|
||||||
st.set_latest_experiment_model_version(finalized.registered_model_version)
|
|
||||||
return MetricsUploadResult(
|
|
||||||
run_id=run_id,
|
|
||||||
registered_model_version=finalized.registered_model_version,
|
|
||||||
metrics_history_uploaded=metrics_history_uploaded,
|
|
||||||
)
|
|
||||||
50
uv.lock
generated
50
uv.lock
generated
@@ -1003,6 +1003,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
|
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itsdangerous"
|
name = "itsdangerous"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
@@ -1665,6 +1674,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" },
|
{ url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.6.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "prettytable"
|
name = "prettytable"
|
||||||
version = "3.17.0"
|
version = "3.17.0"
|
||||||
@@ -1945,6 +1963,34 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" },
|
{ url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "9.0.3"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
{ name = "iniconfig" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pluggy" },
|
||||||
|
{ name = "pygments" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-mock"
|
||||||
|
version = "3.15.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "pytest" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
@@ -2068,6 +2114,8 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "boto3-stubs", extra = ["iam", "s3", "sagemaker"] },
|
{ name = "boto3-stubs", extra = ["iam", "s3", "sagemaker"] },
|
||||||
{ name = "pyright" },
|
{ name = "pyright" },
|
||||||
|
{ name = "pytest" },
|
||||||
|
{ name = "pytest-mock" },
|
||||||
{ name = "ruff" },
|
{ name = "ruff" },
|
||||||
{ name = "types-pyyaml" },
|
{ name = "types-pyyaml" },
|
||||||
]
|
]
|
||||||
@@ -2090,6 +2138,8 @@ requires-dist = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "boto3-stubs", extras = ["iam", "s3", "sagemaker"] },
|
{ name = "boto3-stubs", extras = ["iam", "s3", "sagemaker"] },
|
||||||
{ name = "pyright", specifier = ">=1.1.409" },
|
{ name = "pyright", specifier = ">=1.1.409" },
|
||||||
|
{ name = "pytest", specifier = ">=8.0" },
|
||||||
|
{ name = "pytest-mock", specifier = ">=3.12" },
|
||||||
{ name = "ruff", specifier = ">=0.4" },
|
{ name = "ruff", specifier = ">=0.4" },
|
||||||
{ name = "types-pyyaml" },
|
{ name = "types-pyyaml" },
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user