2 Commits

Author SHA1 Message Date
samirodr
5360a482fc update 2026-06-08 14:59:44 -04:00
samirodr
6a560a8610 match 2026-06-08 14:54:13 -04:00
9 changed files with 93 additions and 394 deletions

View File

@@ -13,18 +13,10 @@ This example takes the ONNX model produced by the SageMaker training example and
Run the training example first and wait for it to complete: Run the training example first and wait for it to complete:
```bash ```bash
bash examples/training/run_training.sh --config config.yaml --wait examples/training/run_training.sh --wait
``` ```
If the dataset is already uploaded to S3, use: The `config.yaml` file must include AI Hub settings:
```bash
bash examples/training/run_training.sh --config config.yaml --skip-upload --wait
```
The training artifact must contain a static-shape `model.onnx`. The training example exports an input named `input` with shape `1x3x160x160`.
Your `config.yaml` must include AI Hub settings:
```yaml ```yaml
aihub: aihub:
@@ -36,16 +28,20 @@ aihub:
output_dir: build/qai-hub output_dir: build/qai-hub
``` ```
You also need local Qualcomm AI Hub SDK authentication configured. Finally, the user needs to authenticate with Qualcomm AI Hub using:
```bash
qai-hub configure --api_token <YOUR_API_TOKEN>
```
## Prepare Inputs ## Prepare Inputs
AI Hub does not consume the raw JPG training images directly. It needs NumPy tensors that match the ONNX model input shape and preprocessing. AI Hub does not consume the raw JPG training images directly. It needs NumPy tensors that match the ONNX model input shape and preprocessing.
Generate calibration and validation inputs: To generate calibration and validation inputs:
```bash ```bash
uv run python examples/ai-hub/prepare_inputs.py python examples/ai-hub/prepare_inputs.py
``` ```
This writes: This writes:
@@ -61,58 +57,23 @@ The script applies the same image preprocessing used by the training example:
- convert to channel-first `1x3x160x160` - convert to channel-first `1x3x160x160`
- normalize with ImageNet mean and standard deviation - normalize with ImageNet mean and standard deviation
Useful options: ## Upload Model to Qualcomm Workbench
The model can be uploaded to Qualcomm Workbench using:
```bash ```bash
uv run python examples/ai-hub/prepare_inputs.py \ qc-cli ai-hub upload examples/training/data/aihub_calibration examples/training/data/inputs.npz
--dataset-dir examples/training/data/flower_photos_sagemaker \
--calibration-dir examples/training/data/aihub_calibration \
--input-file examples/training/data/inputs.npz \
--samples 16
``` ```
## Run AI Hub The first argument is the calibration path for the model and the second argument is the input file, both of which were created by the `prepare_inputs.py` script. For more details, add `--help` after the `upload` command.
After training completes and inputs are prepared: The `upload` command runs the following commands in order:
1. `qc-cli ai-hub quantize`
2. `qc-cli ai-hub compile`
3. `qc-cli ai-hub validate`
4. `qc-cli ai-hub profile`
Finally, download the model from AI Workbench using the following command:
```bash ```bash
bash examples/ai-hub/run_ai_hub.sh --config config.yaml qc-cli ai-hub download
``` ```
By default, the script uses the last SageMaker training job recorded in `.qc-cli.json`. It downloads that job's `model.tar.gz`, extracts `model.onnx`, runs the AI Hub workflow, and downloads the compiled artifact.
To use a specific training job:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-job qc-cli-YYYYMMDD-HHMMSS
```
To resume from a later Workbench step:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-step validate
```
To skip downloading the compiled artifact:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--skip-download
```
## Troubleshooting
If AI Hub reports dynamic input shapes, rerun training with the current training source. AI Hub quantization requires the exported ONNX model to use static input shapes.
If `run_ai_hub.sh` reports missing calibration or input files, run:
```bash
uv run python examples/ai-hub/prepare_inputs.py
```
If validation fails with a missing input name, make sure `config.yaml` and the generated `.npz` both use `input` as the input name.

0
examples/ai-hub/prepare_inputs.py Executable file → Normal file
View File

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env bash
# Wrapper around `qc-cli ai-hub upload` and `qc-cli ai-hub download`.
# Strict mode: abort on a failing command, on use of an unset variable, and on
# a failure anywhere inside a pipeline.
set -o errexit -o nounset -o pipefail

# Defaults; each one can be overridden by the matching command-line option.
CONFIG_PATH='config.yaml'
CALIBRATION_PATH='examples/training/data/aihub_calibration'
INPUT_FILE='examples/training/data/inputs.npz'
FROM_STEP='quantize'

# Optional settings: an empty value means the option is not forwarded.
FROM_JOB=''
MODEL_S3_URI=''
ONNX_PATH=''
INPUT_NAME=''
OUTPUT_PATH=''

# Switched to false by --skip-download.
DOWNLOAD=true
# Print the command-line help text to stdout.
# The here-document delimiter is unquoted on purpose so that $0 and the
# current default values are expanded into the message.
usage() {
  cat <<USAGE
Usage: $0 [options]
Options:
--config PATH Path to qc-cli config file. Default: config.yaml
--calibration PATH Calibration .npz file or directory of .npy samples.
Default: ${CALIBRATION_PATH}
--input-file PATH Validation .npz or .npy inputs. Default: ${INPUT_FILE}
--from-step STEP Resume upload from: quantize, compile, validate, profile.
Default: ${FROM_STEP}
--from-job NAME SageMaker training job whose model artifact should upload.
Defaults to the last training job in local qc-cli state.
--model-s3-uri URI S3 URI of model.tar.gz to upload.
--onnx-path PATH Local ONNX path or ONNX path inside extracted artifact.
--input-name NAME Input name for .npy validation files.
--skip-download Do not download the compiled AI Hub artifact after upload.
--output PATH Destination file for ai-hub download.
-h, --help Show this help.
USAGE
}
# Abort with a usage error unless the option named by $1 is followed by a
# value. $2 is the number of arguments still to be parsed, including the option
# itself. Without this guard, `set -u` turns a trailing option such as
# `--config` into a bare "unbound variable" failure with no hint for the user.
require_value() {
  if (( $2 < 2 )); then
    echo "Option $1 requires a value." >&2
    usage >&2
    exit 1
  fi
}

# Parse command-line options, overriding the defaults assigned earlier.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --config)
      require_value "$1" "$#"
      CONFIG_PATH="$2"
      shift 2
      ;;
    --calibration)
      require_value "$1" "$#"
      CALIBRATION_PATH="$2"
      shift 2
      ;;
    --input-file)
      require_value "$1" "$#"
      INPUT_FILE="$2"
      shift 2
      ;;
    --from-step)
      require_value "$1" "$#"
      FROM_STEP="$2"
      shift 2
      ;;
    --from-job)
      require_value "$1" "$#"
      FROM_JOB="$2"
      shift 2
      ;;
    --model-s3-uri)
      require_value "$1" "$#"
      MODEL_S3_URI="$2"
      shift 2
      ;;
    --onnx-path)
      require_value "$1" "$#"
      ONNX_PATH="$2"
      shift 2
      ;;
    --input-name)
      require_value "$1" "$#"
      INPUT_NAME="$2"
      shift 2
      ;;
    --skip-download)
      # Flag without a value: only one argument is consumed.
      DOWNLOAD=false
      shift
      ;;
    --output)
      require_value "$1" "$#"
      OUTPUT_PATH="$2"
      shift 2
      ;;
    -h|--help)
      usage
      exit 0
      ;;
    *)
      echo "Unknown option: $1" >&2
      usage >&2
      exit 1
      ;;
  esac
done
# Report each argument as its own line on stderr, then exit with status 1.
abort_with() {
  local message
  for message in "$@"; do
    echo "${message}" >&2
  done
  exit 1
}

# The config file must exist as a regular file.
[[ -f "${CONFIG_PATH}" ]] || abort_with "Config not found: ${CONFIG_PATH}"

# Only the four upload steps are accepted as resume points.
case "${FROM_STEP}" in
  quantize | compile | validate | profile) ;;
  *) abort_with "--from-step must be one of: quantize, compile, validate, profile" ;;
esac

# Calibration data may be a file or a directory, so only existence is checked.
[[ -e "${CALIBRATION_PATH}" ]] || abort_with \
  "Calibration path not found: ${CALIBRATION_PATH}" \
  "Pass --calibration with a .npz file or directory of .npy samples."

# Validation inputs must be a regular file.
[[ -f "${INPUT_FILE}" ]] || abort_with \
  "Input file not found: ${INPUT_FILE}" \
  "Pass --input-file with a validation .npz or .npy file."
# Echo the command line (prefixed with "+ ") to stdout, then execute it.
run() {
  printf '+ %s\n' "$*"
  "$@"
}
# Positional arguments and always-present options for `qc-cli ai-hub upload`.
UPLOAD_ARGS=("${CALIBRATION_PATH}" "${INPUT_FILE}")
UPLOAD_ARGS+=(--from-step "${FROM_STEP}" --config "${CONFIG_PATH}")

# Forward each optional setting only when a non-empty value was supplied.
[[ -z "${FROM_JOB}" ]] || UPLOAD_ARGS+=(--from-job "${FROM_JOB}")
[[ -z "${MODEL_S3_URI}" ]] || UPLOAD_ARGS+=(--model-s3-uri "${MODEL_S3_URI}")
[[ -z "${ONNX_PATH}" ]] || UPLOAD_ARGS+=(--onnx-path "${ONNX_PATH}")
[[ -z "${INPUT_NAME}" ]] || UPLOAD_ARGS+=(--input-name "${INPUT_NAME}")

run uv run qc-cli ai-hub upload "${UPLOAD_ARGS[@]}"

# --skip-download ends the script once the upload has succeeded.
[[ "${DOWNLOAD}" != false ]] || exit 0

DOWNLOAD_ARGS=(--config "${CONFIG_PATH}")
[[ -z "${OUTPUT_PATH}" ]] || DOWNLOAD_ARGS+=(--output "${OUTPUT_PATH}")

run uv run qc-cli ai-hub download "${DOWNLOAD_ARGS[@]}"

View File

@@ -1,6 +1,3 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast from typing import Any, cast
import boto3 import boto3
@@ -37,38 +34,3 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"]) return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
    """Temporarily export the AWS credentials of *profile* into ``os.environ``.

    While the context is active, ``AWS_PROFILE``, ``AWS_DEFAULT_REGION``,
    ``AWS_REGION``, ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` are set
    from the resolved session. ``AWS_SESSION_TOKEN`` is set when the
    credentials carry a token and removed otherwise. On exit every touched
    variable is restored to its previous value, or unset if it did not exist.

    Raises:
        RuntimeError: If no credentials can be resolved for the profile, or
            the resolved credentials lack an access key or secret key.
    """
    session = boto3.Session(profile_name=profile, region_name=region)
    resolved = session.get_credentials()
    if resolved is None:
        raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")

    frozen = resolved.get_frozen_credentials()
    if not (frozen.access_key and frozen.secret_key):
        raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")

    overrides = {
        "AWS_PROFILE": profile,
        "AWS_DEFAULT_REGION": region,
        "AWS_REGION": region,
        "AWS_ACCESS_KEY_ID": frozen.access_key,
        "AWS_SECRET_ACCESS_KEY": frozen.secret_key,
    }
    if frozen.token:
        overrides["AWS_SESSION_TOKEN"] = frozen.token

    # The session token is always snapshotted, even when it is not being set,
    # because it may be removed below and must be restored afterwards.
    managed_keys = {*overrides, "AWS_SESSION_TOKEN"}
    snapshot = {name: os.environ.get(name) for name in managed_keys}

    try:
        os.environ.update(overrides)
        if not frozen.token:
            # Drop any pre-existing token so it is not paired with these keys.
            os.environ.pop("AWS_SESSION_TOKEN", None)
        yield
    finally:
        for name, saved in snapshot.items():
            if saved is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = saved

View File

@@ -1 +0,0 @@
"""Cloud provider adapters."""

View File

@@ -1,77 +0,0 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
    """Structural interface for a provider-specific MLflow tracking backend.

    Implementations supply the tracking URI, an authentication context, and
    the parameter/tag dictionaries recorded for training jobs.
    """

    @property
    def provider_name(self) -> str:
        """Name of the provider, as a string."""

    @property
    def profile(self) -> str:
        """Profile string associated with this backend."""

    @property
    def region(self) -> str:
        """Region string associated with this backend."""

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the tracking URI for the named tracking server."""

    def auth_env(self) -> AbstractContextManager[None]:
        """Return a context manager to wrap calls that need authentication."""

    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
        """Return the parameters describing a training job."""

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return the tags describing a training job."""

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Return the parameters describing a training job status."""

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return the tags for a model version, given a training job status."""
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
    """MLflow tracking backend that delegates to the ``aws_mlflow`` helpers.

    The instance is immutable. ``provider_name`` defaults to ``"aws"``.
    """

    profile: str
    region: str
    provider_name: str = "aws"

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the tracking server ARN for this backend's region and profile."""
        arn = aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
        return arn

    def auth_env(self) -> AbstractContextManager[None]:
        """Return the ``aws_mlflow.tracking_auth_env`` context for this backend."""
        return aws_mlflow.tracking_auth_env(self.profile, self.region)

    def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
        """Build the parameter dictionary logged for a training run.

        The ``region`` and ``profile`` arguments are recorded as given; the
        instance's own ``region`` and ``profile`` are not consulted here.
        """
        provider_params: dict[str, Any] = {
            "provider.name": self.provider_name,
            "provider.region": region,
            "provider.profile": profile,
        }
        job = training_job
        sagemaker_params: dict[str, Any] = {
            "sagemaker.role_arn": role_arn,
            "sagemaker.job_name": job.job_name,
            "sagemaker.training_image": job.image_uri,
            "sagemaker.instance_type": job.instance_type,
            "sagemaker.instance_count": job.instance_count,
            "sagemaker.s3_train_uri": job.s3_train_uri,
            "sagemaker.s3_output_path": job.s3_output_path,
            "sagemaker.entry_point": job.entry_point,
            "sagemaker.source_dir": job.source_dir,
        }
        # Provider keys come first so the key order matches the original mapping.
        return provider_params | sagemaker_params

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return the tag that links a run to its training job name."""
        job_name = training_job.job_name
        return {"sagemaker.job_name": job_name}

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Build the parameter dictionary describing a training job status."""
        status = training_job_status
        params: dict[str, Any] = {}
        params["sagemaker.training_status"] = status.status
        params["sagemaker.created_at"] = status.created
        params["sagemaker.modified_at"] = status.modified
        params["sagemaker.model_artifacts"] = status.model_artifacts
        params["sagemaker.failure_reason"] = status.failure_reason
        return params

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return the tag that links a model version to the status object's ``name``."""
        job_name = training_job_status.name
        return {"sagemaker.job_name": job_name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
    """Build the MLflow tracking backend from the ``aws`` section of *cfg*.

    An ``AwsMlflowTrackingBackend`` is always returned, populated from
    ``cfg.aws.profile`` and ``cfg.aws.region``.
    """
    backend = AwsMlflowTrackingBackend(
        profile=cfg.aws.profile,
        region=cfg.aws.region,
    )
    return backend

View File

@@ -14,7 +14,7 @@ from src.config import Config
from src.qualcomm import aihub_jobs from src.qualcomm import aihub_jobs
from src.qualcomm.artifacts import resolve_onnx from src.qualcomm.artifacts import resolve_onnx
app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm AI Hub") app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm Workbench")
_RUNTIME_EXTENSIONS = { _RUNTIME_EXTENSIONS = {
"tflite": "tflite", "tflite": "tflite",

View File

@@ -1 +0,0 @@

View File

@@ -5,7 +5,7 @@ from typing import Any, Protocol
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config from src.aws import mlflow as aws_mlflow
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
@@ -30,7 +30,6 @@ class MlflowTracker:
experiment_name: str experiment_name: str
registered_model_name: str registered_model_name: str
register_trained_models: bool register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod @classmethod
def from_config(cls, cfg: Config) -> Tracker: def from_config(cls, cfg: Config) -> Tracker:
@@ -43,10 +42,11 @@ class MlflowTracker:
if not tracking_server_name: if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.") raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_backend = mlflow_tracking_backend_from_config(cfg) tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region,
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name) cfg.aws.profile,
with tracking_backend.auth_env(): tracking_server_name,
)
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,30 +55,34 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name, experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name, registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models, register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
run = mlflow.start_run(run_name=training_job.job_name) run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id) run_id = str(run.info.run_id)
self._log_params( params = {
self.tracking_backend.training_run_params( "aws.region": region,
training_job, "aws.profile": profile,
region=region, "sagemaker.role_arn": role_arn,
profile=profile, "sagemaker.job_name": training_job.job_name,
role_arn=role_arn, "sagemaker.training_image": training_job.image_uri,
) "sagemaker.instance_type": training_job.instance_type,
) "sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name, "qc_cli.source": "sagemaker",
"qc_cli.command": "train start", "qc_cli.command": "train start",
**self.tracking_backend.training_run_tags(training_job), "sagemaker.job_name": training_job.job_name,
} }
) )
mlflow.end_run() mlflow.end_run()
@@ -88,9 +92,16 @@ class MlflowTracker:
if not run_id: if not run_id:
return None return None
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status)) self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status") mlflow.set_tag("qc_cli.command", "train status")
@@ -110,8 +121,8 @@ class MlflowTracker:
tags={ tags={
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name, "qc_cli.source": "sagemaker",
**self.tracking_backend.model_version_tags(training_job_status), "sagemaker.job_name": training_job_status.name,
}, },
) )
version_number = str(version.version) version_number = str(version.version)