11 Commits

Author SHA1 Message Date
d3ebd2cc5f inital ai hub implementation 2026-06-01 15:14:10 -04:00
57a8a0a9c4 rename and future steps 2026-05-29 15:40:38 -04:00
a43c792cfd reorg 2026-05-29 14:52:34 -04:00
cf6a561e2f clean 2026-05-29 14:36:57 -04:00
416e51901d space 2026-05-29 14:33:17 -04:00
556797cf13 remove 2026-05-29 14:31:36 -04:00
19fef8638b mlflow not being an optional lin 2026-05-29 14:29:05 -04:00
58681cef82 command to create presigned URL for MLFlow 2026-05-27 10:52:08 -04:00
e1c8d6574f omit server name when created with config 2026-05-27 10:23:53 -04:00
35d25d8967 Merge branch 'main' into ml-flow 2026-05-27 08:58:46 -04:00
b907a74525 wip mlflow implementation 2026-05-26 15:03:53 -04:00
20 changed files with 329 additions and 789 deletions

View File

@@ -67,8 +67,7 @@ sagemaker:
hyperparameters: {} hyperparameters: {}
aihub: aihub:
device: device: Samsung Galaxy S25 (Family)
name: Samsung Galaxy S25 (Family)
target_runtime: tflite target_runtime: tflite
input_specs: {} # Required before running qc-cli ai-hub commands input_specs: {} # Required before running qc-cli ai-hub commands
job_name: null # Optional prefix for AI Hub Workbench jobs job_name: null # Optional prefix for AI Hub Workbench jobs
@@ -110,10 +109,10 @@ When MLflow is enabled, `train start` creates an MLflow run for the SageMaker jo
To open the managed SageMaker MLflow UI, request a fresh presigned URL: To open the managed SageMaker MLflow UI, request a fresh presigned URL:
```bash ```bash
qc-cli mlflow open --config config.yaml qc-cli infra mlflow-url --config config.yaml
``` ```
This opens a browser to a fresh presigned URL. It works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead. This works for `mode: create` and for `mode: existing` when the existing server is managed by Amazon SageMaker. In `create` mode, the command uses the CLI-managed tracking server name. In `existing` mode, it uses `mlflow.tracking_server_name`. If the existing MLflow server is external to SageMaker, open it with that server's own URL instead.
## Commands ## Commands
@@ -125,12 +124,6 @@ qc-cli init --output <path> Write config to a custom path
qc-cli init --force Overwrite an existing config file qc-cli init --force Overwrite an existing config file
``` ```
### `mlflow`
```
qc-cli mlflow open Open a presigned MLflow UI URL in a browser
```
### `infra` ### `infra`
``` ```
@@ -138,6 +131,7 @@ qc-cli infra setup Deploy the CDK stack
qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap qc-cli infra setup --no-bootstrap Deploy without running CDK bootstrap
qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN qc-cli infra setup --cloudformation-execution-policy <arn> Set CDK bootstrap execution policy ARN
qc-cli infra status Show CDK stack/resource status qc-cli infra status Show CDK stack/resource status
qc-cli infra mlflow-url Print a presigned MLflow UI URL
qc-cli infra destroy Destroy stack, retaining S3 data qc-cli infra destroy Destroy stack, retaining S3 data
qc-cli infra destroy --yes Destroy stack without confirmation qc-cli infra destroy --yes Destroy stack without confirmation
qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data qc-cli infra destroy --delete-bucket-data Destroy stack and delete S3 data
@@ -186,17 +180,6 @@ qc-cli ai-hub download [--model-id ID] [--output PATH]
`ai-hub upload` runs the four Workbench upload steps in order: quantize, compile, validate, and profile. Use `--from-step compile`, `--from-step validate`, or `--from-step profile` to resume from saved local state after a completed earlier step. `ai-hub upload` runs the four Workbench upload steps in order: quantize, compile, validate, and profile. Use `--from-step compile`, `--from-step validate`, or `--from-step profile` to resume from saved local state after a completed earlier step.
Resume behavior:
```text
--from-step quantize Run quantize, compile, validate, and profile.
--from-step compile Skip quantize; compile the last quantized model unless an explicit source is passed.
--from-step validate Skip quantize and compile; validate the last compiled model.
--from-step profile Skip quantize, compile, and validate; profile the last compiled model.
```
When a step runs in the current command, `upload` passes its returned model ID directly to the next step. When a step is skipped, the next step resolves the needed model ID from `.qc-cli.json`. This avoids re-running earlier AI Hub jobs when you only need to continue from a later step.
`ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options (`--onnx-path`, `--model-s3-uri`, `--from-job`), last quantized model from state, then the last training job from local state. `ai-hub download` is separate because downloading the optimized artifact is outside the four-step Workbench upload loop. `ai-hub compile` resolves model sources in this order: `--model-id`, explicit source options (`--onnx-path`, `--model-s3-uri`, `--from-job`), last quantized model from state, then the last training job from local state. `ai-hub download` is separate because downloading the optimized artifact is outside the four-step Workbench upload loop.
AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions. AI Hub authentication currently uses the local `qai-hub` SDK configuration. A planned follow-up is to support AWS Systems Manager Parameter Store `SecureString` for team-managed tokens, where `config.yaml` stores only a parameter name such as `/qc-cli/aihub/token`, AWS KMS encrypts the token at rest, and the CLI retrieves it at runtime with `ssm:GetParameter` plus `kms:Decrypt` permissions.

View File

@@ -1,118 +0,0 @@
# Qualcomm AI Hub Example
This example takes the ONNX model produced by the SageMaker training example and runs the Qualcomm AI Hub upload workflow:
1. Quantize
2. Compile
3. Validate
4. Profile
5. Download the compiled artifact
## Prerequisites
Run the training example first and wait for it to complete:
```bash
bash examples/training/run_training.sh --config config.yaml --wait
```
If the dataset is already uploaded to S3, use:
```bash
bash examples/training/run_training.sh --config config.yaml --skip-upload --wait
```
The training artifact must contain a static-shape `model.onnx`. The training example exports an input named `input` with shape `1x3x160x160`.
Your `config.yaml` must include AI Hub settings:
```yaml
aihub:
device:
name: Samsung Galaxy S25 (Family)
target_runtime: tflite
input_specs:
input: [[1, 3, 160, 160], float32]
output_dir: build/qai-hub
```
You also need local Qualcomm AI Hub SDK authentication configured.
## Prepare Inputs
AI Hub does not consume the raw JPG training images directly. It needs NumPy tensors that match the ONNX model input shape and preprocessing.
Generate calibration and validation inputs:
```bash
uv run python examples/ai-hub/prepare_inputs.py
```
This writes:
```text
examples/training/data/aihub_calibration/*.npy
examples/training/data/inputs.npz
```
The script applies the same image preprocessing used by the training example:
- resize to `160x160`
- convert to channel-first `1x3x160x160`
- normalize with ImageNet mean and standard deviation
Useful options:
```bash
uv run python examples/ai-hub/prepare_inputs.py \
--dataset-dir examples/training/data/flower_photos_sagemaker \
--calibration-dir examples/training/data/aihub_calibration \
--input-file examples/training/data/inputs.npz \
--samples 16
```
## Run AI Hub
After training completes and inputs are prepared:
```bash
bash examples/ai-hub/run_ai_hub.sh --config config.yaml
```
By default, the script uses the last SageMaker training job recorded in `.qc-cli.json`. It downloads that job's `model.tar.gz`, extracts `model.onnx`, runs the AI Hub workflow, and downloads the compiled artifact.
To use a specific training job:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-job qc-cli-YYYYMMDD-HHMMSS
```
To resume from a later Workbench step:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-step validate
```
To skip downloading the compiled artifact:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--skip-download
```
## Troubleshooting
If AI Hub reports dynamic input shapes, rerun training with the current training source. AI Hub quantization requires the exported ONNX model to use static input shapes.
If `run_ai_hub.sh` reports missing calibration or input files, run:
```bash
uv run python examples/ai-hub/prepare_inputs.py
```
If validation fails with a missing input name, make sure `config.yaml` and the generated `.npz` both use `input` as the input name.

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python3
"""Prepare Qualcomm AI Hub calibration and validation inputs for the training example."""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dataset-dir",
type=Path,
default=Path("examples/training/data/flower_photos_sagemaker"),
help="ImageFolder-style dataset used for training.",
)
parser.add_argument(
"--calibration-dir",
type=Path,
default=Path("examples/training/data/aihub_calibration"),
help="Directory where .npy calibration samples will be written.",
)
parser.add_argument(
"--input-file",
type=Path,
default=Path("examples/training/data/inputs.npz"),
help="Validation .npz input file for qc-cli ai-hub validate.",
)
parser.add_argument("--input-name", default="input", help="ONNX input name.")
parser.add_argument("--image-size", type=int, default=160, help="Square image size used by training.")
parser.add_argument("--samples", type=int, default=16, help="Number of calibration samples to write.")
return parser.parse_args()
def preprocess_image(path: Path, image_size: int) -> np.ndarray:
image = Image.open(path).convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
array = np.asarray(image, dtype=np.float32) / 255.0
array = np.transpose(array, (2, 0, 1))
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
return ((array - mean) / std)[None, ...].astype("float32")
def main() -> None:
args = parse_args()
images = sorted(p for p in args.dataset_dir.rglob("*") if p.suffix.lower() in IMAGE_EXTENSIONS)
if not images:
raise SystemExit(f"No images found under {args.dataset_dir}")
if args.samples < 1:
raise SystemExit("--samples must be at least 1")
args.calibration_dir.mkdir(parents=True, exist_ok=True)
args.input_file.parent.mkdir(parents=True, exist_ok=True)
sample_count = min(args.samples, len(images))
prepared = []
for index, image_path in enumerate(images[:sample_count]):
sample = preprocess_image(image_path, args.image_size)
np.save(args.calibration_dir / f"sample_{index:03d}.npy", sample)
prepared.append(sample)
np.savez(args.input_file, **{args.input_name: prepared[0]})
print(f"Wrote {sample_count} calibration samples to {args.calibration_dir}")
print(f"Wrote validation input to {args.input_file}")
if __name__ == "__main__":
main()

View File

@@ -1,156 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
CONFIG_PATH="config.yaml"
CALIBRATION_PATH="examples/training/data/aihub_calibration"
INPUT_FILE="examples/training/data/inputs.npz"
FROM_STEP="quantize"
FROM_JOB=""
MODEL_S3_URI=""
ONNX_PATH=""
INPUT_NAME=""
DOWNLOAD=true
OUTPUT_PATH=""
usage() {
cat <<EOF
Usage: $0 [options]
Options:
--config PATH Path to qc-cli config file. Default: config.yaml
--calibration PATH Calibration .npz file or directory of .npy samples.
Default: ${CALIBRATION_PATH}
--input-file PATH Validation .npz or .npy inputs. Default: ${INPUT_FILE}
--from-step STEP Resume upload from: quantize, compile, validate, profile.
Default: ${FROM_STEP}
--from-job NAME SageMaker training job whose model artifact should upload.
Defaults to the last training job in local qc-cli state.
--model-s3-uri URI S3 URI of model.tar.gz to upload.
--onnx-path PATH Local ONNX path or ONNX path inside extracted artifact.
--input-name NAME Input name for .npy validation files.
--skip-download Do not download the compiled AI Hub artifact after upload.
--output PATH Destination file for ai-hub download.
-h, --help Show this help.
EOF
}
while [[ $# -gt 0 ]]; do
case "$1" in
--config)
CONFIG_PATH="$2"
shift 2
;;
--calibration)
CALIBRATION_PATH="$2"
shift 2
;;
--input-file)
INPUT_FILE="$2"
shift 2
;;
--from-step)
FROM_STEP="$2"
shift 2
;;
--from-job)
FROM_JOB="$2"
shift 2
;;
--model-s3-uri)
MODEL_S3_URI="$2"
shift 2
;;
--onnx-path)
ONNX_PATH="$2"
shift 2
;;
--input-name)
INPUT_NAME="$2"
shift 2
;;
--skip-download)
DOWNLOAD=false
shift
;;
--output)
OUTPUT_PATH="$2"
shift 2
;;
-h|--help)
usage
exit 0
;;
*)
echo "Unknown option: $1" >&2
usage >&2
exit 1
;;
esac
done
if [[ ! -f "${CONFIG_PATH}" ]]; then
echo "Config not found: ${CONFIG_PATH}" >&2
exit 1
fi
case "${FROM_STEP}" in
quantize|compile|validate|profile)
;;
*)
echo "--from-step must be one of: quantize, compile, validate, profile" >&2
exit 1
;;
esac
if [[ ! -e "${CALIBRATION_PATH}" ]]; then
echo "Calibration path not found: ${CALIBRATION_PATH}" >&2
echo "Pass --calibration with a .npz file or directory of .npy samples." >&2
exit 1
fi
if [[ ! -f "${INPUT_FILE}" ]]; then
echo "Input file not found: ${INPUT_FILE}" >&2
echo "Pass --input-file with a validation .npz or .npy file." >&2
exit 1
fi
run() {
echo "+ $*"
"$@"
}
UPLOAD_ARGS=(
"${CALIBRATION_PATH}"
"${INPUT_FILE}"
--from-step "${FROM_STEP}"
--config "${CONFIG_PATH}"
)
if [[ -n "${FROM_JOB}" ]]; then
UPLOAD_ARGS+=(--from-job "${FROM_JOB}")
fi
if [[ -n "${MODEL_S3_URI}" ]]; then
UPLOAD_ARGS+=(--model-s3-uri "${MODEL_S3_URI}")
fi
if [[ -n "${ONNX_PATH}" ]]; then
UPLOAD_ARGS+=(--onnx-path "${ONNX_PATH}")
fi
if [[ -n "${INPUT_NAME}" ]]; then
UPLOAD_ARGS+=(--input-name "${INPUT_NAME}")
fi
run uv run qc-cli ai-hub upload "${UPLOAD_ARGS[@]}"
if [[ "${DOWNLOAD}" == false ]]; then
exit 0
fi
DOWNLOAD_ARGS=(--config "${CONFIG_PATH}")
if [[ -n "${OUTPUT_PATH}" ]]; then
DOWNLOAD_ARGS+=(--output "${OUTPUT_PATH}")
fi
run uv run qc-cli ai-hub download "${DOWNLOAD_ARGS[@]}"

View File

@@ -126,6 +126,10 @@ def export_onnx(model: nn.Module, model_dir: Path, image_size: int) -> None:
do_constant_folding=True, do_constant_folding=True,
input_names=["input"], input_names=["input"],
output_names=["logits"], output_names=["logits"],
dynamic_axes={
"input": {0: "batch_size"},
"logits": {0: "batch_size"},
},
) )

View File

@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "qc-cli" name = "qc-cli"
version = "0.1.0" version = "0.1.0"
description = "CLI for training and deploying models for Qualcomm AI Hub" description = "CLI for SageMaker ONNX training and Qualcomm AI Hub optimization"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"aws-cdk-lib>=2.180.0", "aws-cdk-lib>=2.180.0",
@@ -29,6 +29,8 @@ packages = ["src"]
[dependency-groups] [dependency-groups]
dev = [ dev = [
"boto3-stubs[iam,s3,sagemaker]", "boto3-stubs[iam,s3,sagemaker]",
"pytest>=8.0",
"pytest-mock>=3.12",
"pyright>=1.1.409", "pyright>=1.1.409",
"types-PyYAML", "types-PyYAML",
"ruff>=0.4", "ruff>=0.4",

View File

@@ -1,6 +1,3 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast from typing import Any, cast
import boto3 import boto3
@@ -37,38 +34,3 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"]) return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
if credentials is None:
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
frozen_credentials = credentials.get_frozen_credentials()
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
env_updates = {
"AWS_PROFILE": profile,
"AWS_DEFAULT_REGION": region,
"AWS_REGION": region,
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
}
if frozen_credentials.token:
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
previous_env = {key: os.environ.get(key) for key in restore_keys}
try:
os.environ.update(env_updates)
if not frozen_credentials.token:
os.environ.pop("AWS_SESSION_TOKEN", None)
yield
finally:
for key, value in previous_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

View File

@@ -1 +0,0 @@
"""Cloud provider adapters."""

View File

@@ -1,77 +0,0 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
@property
def provider_name(self) -> str: ...
@property
def profile(self) -> str: ...
@property
def region(self) -> str: ...
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
def auth_env(self) -> AbstractContextManager[None]: ...
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
profile: str
region: str
provider_name: str = "aws"
def get_tracking_uri(self, tracking_server_name: str) -> str:
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
def auth_env(self) -> AbstractContextManager[None]:
return aws_mlflow.tracking_auth_env(self.profile, self.region)
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
return {
"provider.name": self.provider_name,
"provider.region": region,
"provider.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job.job_name}
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
return {
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)

View File

@@ -4,9 +4,7 @@ from enum import StrEnum
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import qai_hub.hub as hub
import typer import typer
from qai_hub.client import Device
from src import state as state_ops from src import state as state_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
@@ -101,33 +99,6 @@ def _model_id_or_state(config_path: str, model_id: str | None, *, quantized: boo
return resolved return resolved
def _device_selector(device: Device) -> str:
parts: list[str] = []
if device.name:
parts.append(f"name={device.name!r}")
if device.os:
parts.append(f"os={device.os!r}")
if device.attributes:
parts.append(f"attributes={device.attributes!r}")
return ", ".join(parts) if parts else "empty selector"
def _validate_device(cfg: Config) -> None:
device = cfg.aihub.device
try:
matches = hub.get_devices(name=device.name, os=device.os, attributes=device.attributes)
except Exception as e:
CONSOLE.print(f"[red]Unable to validate AI Hub device {_device_selector(device)}: {e}[/red]")
raise typer.Exit(1)
if matches:
return
CONSOLE.print(f"[red]AI Hub device not found: {_device_selector(device)}[/red]")
CONSOLE.print("Run [bold]qai-hub list-devices[/bold] to see valid device names.")
raise typer.Exit(1)
def _quantize_step( def _quantize_step(
cfg: Config, cfg: Config,
config_path: str, config_path: str,
@@ -185,7 +156,6 @@ def _compile_step(
prefer_quantized: bool, prefer_quantized: bool,
) -> str: ) -> str:
st = state_ops.store(config_path) st = state_ops.store(config_path)
_validate_device(cfg)
specs = _input_specs(cfg) specs = _input_specs(cfg)
model: Any model: Any
@@ -214,7 +184,7 @@ def _compile_step(
try: try:
result = aihub_jobs.submit_compile_job( result = aihub_jobs.submit_compile_job(
model=model, model=model,
device=cfg.aihub.device, device_name=cfg.aihub.device,
input_specs=specs, input_specs=specs,
target_runtime=cfg.aihub.target_runtime, target_runtime=cfg.aihub.target_runtime,
options=cfg.aihub.compile_options, options=cfg.aihub.compile_options,
@@ -244,7 +214,6 @@ def _validate_step(
model_id: str | None, model_id: str | None,
input_name: str | None, input_name: str | None,
) -> str: ) -> str:
_validate_device(cfg)
specs = _input_specs(cfg) specs = _input_specs(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id) resolved_model_id = _model_id_or_state(config_path, model_id)
try: try:
@@ -278,7 +247,6 @@ def _validate_step(
def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str: def _profile_step(cfg: Config, config_path: str, model_id: str | None) -> str:
_validate_device(cfg)
resolved_model_id = _model_id_or_state(config_path, model_id) resolved_model_id = _model_id_or_state(config_path, model_id)
try: try:
result = aihub_jobs.submit_profile_job( result = aihub_jobs.submit_profile_job(

View File

@@ -150,6 +150,35 @@ def status(config: str = CONFIG_OPT) -> None:
CONSOLE.print(table) CONSOLE.print(table)
@app.command(name="mlflow-url")
def mlflow_url(config: str = CONFIG_OPT) -> None:
"""Print a presigned URL for the configured MLflow tracking server."""
cfg = load_cfg(config)
tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
try:
url = mlflow.create_presigned_tracking_server_url(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
except Exception as e:
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"Reason: {e}")
CONSOLE.print(
"This command can create presigned URLs only for MLflow tracking servers managed by "
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
)
raise typer.Exit(1)
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"MLflow UI: {url}")
@app.command() @app.command()
def destroy( def destroy(
config: str = CONFIG_OPT, config: str = CONFIG_OPT,

View File

@@ -1,40 +0,0 @@
import secrets
from pathlib import Path
import typer
import yaml
from src.commands.utils import CONSOLE
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
app = typer.Typer()
@app.command()
def init(
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
) -> None:
"""Write a starter config.yaml to the current directory."""
dest = Path(output)
if dest.exists() and not force:
CONSOLE.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
raise typer.Exit(1)
config = _new_isolated_config()
dest.parent.mkdir(parents=True, exist_ok=True)
config_data = config.model_dump(mode="json")
config_data["sagemaker"].pop("role_name", None)
with open(dest, "w") as f:
yaml.safe_dump(config_data, f, sort_keys=False)
CONSOLE.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
CONSOLE.print("Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands.")
def _new_isolated_config() -> Config:
suffix = secrets.token_hex(6)
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
config = Config(infra=InfraConfig(stack_name=namespace))
config.s3 = S3Config(bucket=f"{namespace}-data")
return config

View File

@@ -1,41 +0,0 @@
import webbrowser
import typer
from src.aws import mlflow as aws_mlflow
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
app = typer.Typer(help="Manage MLflow tracking server access")
@app.command(name="open")
def open_mlflow(config: str = CONFIG_OPT) -> None:
"""Open a presigned URL for the configured MLflow tracking server."""
cfg = load_cfg(config)
tracking_server_name = cfg.effective_mlflow_tracking_server_name
if not tracking_server_name:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
try:
url = aws_mlflow.create_presigned_tracking_server_url(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
except Exception as e:
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"Reason: {e}")
CONSOLE.print(
"This command can create presigned URLs only for MLflow tracking servers managed by "
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
)
raise typer.Exit(1)
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"MLflow UI: {url}")
if webbrowser.open(url):
CONSOLE.print("[green]✓[/green] Opened MLflow UI in your browser.")
else:
CONSOLE.print("[yellow]Could not open a browser automatically. Open the URL above manually.[/yellow]")

View File

@@ -101,7 +101,7 @@ def start(config: str = CONFIG_OPT) -> None:
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]") CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
if run_id: if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]") CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]") CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@@ -151,7 +151,7 @@ def status(
st.set_latest_experiment_model_version(version) st.set_latest_experiment_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])") CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled: if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli mlflow open[/cyan]") CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
@app.command(name="list") @app.command(name="list")

View File

@@ -1,70 +0,0 @@
from pathlib import Path
import typer
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from src.aws import s3 as s3_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
app = typer.Typer()
@app.command()
def upload(
path: Path = typer.Argument(..., help="Local file or directory to upload"),
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a local file or directory to S3."""
cfg = load_cfg(config)
if path.is_file():
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
try:
with CONSOLE.status(f"Uploading {path.name}..."):
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
except Exception as e:
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] {path.name} -> {uri}")
return
if path.is_dir():
if s3_key is not None:
CONSOLE.print("[red]--s3-key can only be used when uploading a single file.[/red]")
raise typer.Exit(1)
files = [file for file in path.rglob("*") if file.is_file()]
if not files:
CONSOLE.print("[yellow]No files found in directory.[/yellow]")
raise typer.Exit(0)
prefix = cfg.s3.data_prefix
CONSOLE.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
try:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=CONSOLE,
) as progress:
task = progress.add_task("Uploading...", total=len(files))
count = s3_ops.upload_dir(
cfg.aws.region,
cfg.aws.profile,
cfg.s3.bucket,
str(path),
prefix,
on_progress=lambda: progress.advance(task),
)
except Exception as e:
CONSOLE.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
CONSOLE.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
return
CONSOLE.print(f"[red]Path not found: {path}[/red]")
raise typer.Exit(1)

View File

@@ -4,8 +4,7 @@ from typing import Any, Literal, TypedDict
from mypy_boto3_s3.literals import BucketLocationConstraintType from mypy_boto3_s3.literals import BucketLocationConstraintType
from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType from mypy_boto3_sagemaker.literals import TrainingInstanceTypeType
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, model_validator
from qai_hub.client import Device
class MlflowMode(StrEnum): class MlflowMode(StrEnum):
@@ -82,7 +81,7 @@ class SageMakerConfig(BaseModel):
class AIHubConfig(BaseModel): class AIHubConfig(BaseModel):
device: Device = Field(default_factory=lambda: Device("Samsung Galaxy S25 (Family)")) device: str = "Samsung Galaxy S25 (Family)"
target_runtime: str = "tflite" target_runtime: str = "tflite"
input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict) input_specs: dict[str, tuple[list[int], str]] = Field(default_factory=dict)
job_name: str | None = None job_name: str | None = None
@@ -92,13 +91,6 @@ class AIHubConfig(BaseModel):
quantize_options: str | None = None quantize_options: str | None = None
output_dir: str = "build/qai-hub" output_dir: str = "build/qai-hub"
@field_validator("device", mode="before")
@classmethod
def parse_device(cls, value: Any) -> Any:
if isinstance(value, str):
return Device(value)
return value
class MlflowConfig(BaseModel): class MlflowConfig(BaseModel):
mode: MlflowMode = MlflowMode.disabled mode: MlflowMode = MlflowMode.disabled

View File

@@ -1,14 +1,115 @@
import typer import secrets
from pathlib import Path
from src.commands import ai_hub, infra, init, mlflow, train, upload import typer
import yaml
from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from src.aws import s3 as s3_ops
from src.commands import ai_hub, infra, train
from src.commands.utils import CONFIG_OPT, load_cfg
from src.config import GENERATED_STACK_PREFIX, Config, InfraConfig, S3Config
app = typer.Typer( app = typer.Typer(
help="qc-cli: End-to-end model managment for Qualcomm AI Hub.", help="qc-cli: End-to-end model managment for Qualcomm AI Hub.",
no_args_is_help=True, no_args_is_help=True,
) )
app.add_typer(init.app)
app.add_typer(upload.app)
app.add_typer(mlflow.app, name="mlflow")
app.add_typer(infra.app, name="infra") app.add_typer(infra.app, name="infra")
app.add_typer(train.app, name="train") app.add_typer(train.app, name="train")
app.add_typer(ai_hub.app, name="ai-hub") app.add_typer(ai_hub.app, name="ai-hub")
console = Console()
@app.command()
def init(
output: str = typer.Option("config.yaml", help="Destination path for the config file"),
force: bool = typer.Option(False, "--force", "-f", help="Overwrite an existing config file"),
) -> None:
"""Write a starter config.yaml to the current directory."""
dest = Path(output)
if dest.exists() and not force:
console.print(f"[yellow]{dest} already exists.[/yellow] Use --force to overwrite.")
raise typer.Exit(1)
config = _new_isolated_config()
dest.parent.mkdir(parents=True, exist_ok=True)
config_data = config.model_dump(mode="json")
config_data["sagemaker"].pop("role_name", None)
with open(dest, "w") as f:
yaml.safe_dump(config_data, f, sort_keys=False)
console.print(f"[green]✓[/green] Config written to [bold]{dest}[/bold]")
console.print(
"Edit [cyan]sagemaker.training.image_uri[/cyan] before running training commands."
)
def _new_isolated_config() -> Config:
suffix = secrets.token_hex(6)
namespace = f"{GENERATED_STACK_PREFIX}{suffix}"
config = Config(infra=InfraConfig(stack_name=namespace))
config.s3 = S3Config(bucket=f"{namespace}-data")
return config
@app.command()
def upload(
path: Path = typer.Argument(..., help="Local file or directory to upload"),
s3_key: str | None = typer.Option(None, "--s3-key", help="S3 key for file uploads"),
config: str = CONFIG_OPT,
) -> None:
"""Upload a local file or directory to S3."""
cfg = load_cfg(config)
if path.is_file():
key = s3_key or f"{cfg.s3.data_prefix.rstrip('/')}/{path.name}"
try:
with console.status(f"Uploading {path.name}..."):
uri = s3_ops.upload_file(cfg.aws.region, cfg.aws.profile, cfg.s3.bucket, str(path), key)
except Exception as e:
console.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
console.print(f"[green]✓[/green] {path.name} -> {uri}")
return
if path.is_dir():
if s3_key is not None:
console.print("[red]--s3-key can only be used when uploading a single file.[/red]")
raise typer.Exit(1)
files = [file for file in path.rglob("*") if file.is_file()]
if not files:
console.print("[yellow]No files found in directory.[/yellow]")
raise typer.Exit(0)
prefix = cfg.s3.data_prefix
console.print(f"Uploading {len(files)} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
try:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
) as progress:
task = progress.add_task("Uploading...", total=len(files))
count = s3_ops.upload_dir(
cfg.aws.region,
cfg.aws.profile,
cfg.s3.bucket,
str(path),
prefix,
on_progress=lambda: progress.advance(task),
)
except Exception as e:
console.print(f"[red]Upload failed: {e}[/red]")
raise typer.Exit(1)
console.print(f"[green]✓[/green] Uploaded {count} files to s3://{cfg.s3.bucket}/{prefix.rstrip('/')}/")
return
console.print(f"[red]Path not found: {path}[/red]")
raise typer.Exit(1)

View File

@@ -1,26 +1,32 @@
from pathlib import Path from pathlib import Path
from typing import Any, TypedDict from typing import Any
import qai_hub.hub as hub
from qai_hub.client import CompileJob, Device, InferenceJob, Model, ProfileJob, QuantizeDtype, QuantizeJob
class ModelJobResult(TypedDict): def _hub() -> Any:
job: CompileJob | QuantizeJob import qai_hub as hub
job_id: str
model: Model return hub
model_id: str
class InferenceJobResult(TypedDict): def _id(obj: Any) -> str:
job: InferenceJob for attr in ("model_id", "job_id", "id"):
job_id: str value = getattr(obj, attr, None)
outputs: Any if value:
return str(value)
return str(obj)
class ProfileJobResult(TypedDict): def _target_model(job: Any) -> Any:
job: ProfileJob if hasattr(job, "get_target_model"):
job_id: str return job.get_target_model()
model = getattr(job, "target_model", None)
if model is not None:
return model
return job
def get_model(model_id: str) -> Any:
return _hub().get_model(model_id)
def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]: def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
@@ -29,13 +35,14 @@ def _dataset_entries(inputs: dict[str, Any]) -> dict[str, list[Any]]:
def submit_compile_job( def submit_compile_job(
model: Any, model: Any,
device: Device, device_name: str,
input_specs: dict[str, tuple[tuple[int, ...], str]], input_specs: dict[str, tuple[tuple[int, ...], str]],
target_runtime: str, target_runtime: str,
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
model_name: str | None = None, model_name: str | None = None,
) -> ModelJobResult: ) -> dict[str, Any]:
hub = _hub()
compile_options = f"--target_runtime {target_runtime}" compile_options = f"--target_runtime {target_runtime}"
if options: if options:
compile_options = f"{compile_options} {options}" compile_options = f"{compile_options} {options}"
@@ -45,56 +52,58 @@ def submit_compile_job(
model_arg = str(model) model_arg = str(model)
elif isinstance(model, str): elif isinstance(model, str):
candidate = Path(model) candidate = Path(model)
model_arg = model if candidate.exists() or candidate.suffix else hub.get_model(model) model_arg = model if candidate.exists() or candidate.suffix else get_model(model)
if model_name and isinstance(model_arg, str) and Path(model_arg).exists(): if model_name and isinstance(model_arg, str) and Path(model_arg).exists():
model_arg = hub.upload_model(model_arg, name=model_name) model_arg = hub.upload_model(model_arg, name=model_name)
job = hub.submit_compile_job( job = hub.submit_compile_job(
model=model_arg, model=model_arg,
device=device, device=hub.Device(device_name),
name=job_name, name=job_name,
input_specs=input_specs, input_specs=input_specs,
options=compile_options, options=compile_options,
) )
target_model = job.get_target_model() target_model = _target_model(job)
if target_model is None: if target_model is None:
raise RuntimeError(f"Compile job {job.job_id} did not produce a target model.") raise RuntimeError(f"Compile job {_id(job)} did not produce a target model.")
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)} return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)}
def submit_inference_job( def submit_inference_job(
model_id: str, model_id: str,
device: Device, device_name: str,
inputs: dict[str, Any], inputs: dict[str, Any],
output_dir: str | Path, output_dir: str | Path,
job_name: str | None = None, job_name: str | None = None,
) -> InferenceJobResult: ) -> dict[str, Any]:
hub = _hub()
job = hub.submit_inference_job( job = hub.submit_inference_job(
model=hub.get_model(model_id), model=get_model(model_id),
device=device, device=hub.Device(device_name),
inputs=_dataset_entries(inputs), inputs=_dataset_entries(inputs),
name=job_name, name=job_name,
) )
out = Path(output_dir) out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True) out.mkdir(parents=True, exist_ok=True)
data = job.download_output_data(str(out)) data = job.download_output_data(str(out))
return {"job": job, "job_id": str(job.job_id), "outputs": data} return {"job": job, "job_id": _id(job), "outputs": data}
def submit_profile_job( def submit_profile_job(
model_id: str, model_id: str,
device: Device, device_name: str,
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
) -> ProfileJobResult: ) -> dict[str, Any]:
hub = _hub()
job = hub.submit_profile_job( job = hub.submit_profile_job(
model=hub.get_model(model_id), model=get_model(model_id),
device=device, device=hub.Device(device_name),
name=job_name, name=job_name,
options=options or "", options=options or "",
) )
return {"job": job, "job_id": str(job.job_id)} return {"job": job, "job_id": _id(job)}
def submit_quantize_job( def submit_quantize_job(
@@ -103,27 +112,33 @@ def submit_quantize_job(
options: str | None = None, options: str | None = None,
job_name: str | None = None, job_name: str | None = None,
model_name: str | None = None, model_name: str | None = None,
) -> ModelJobResult: ) -> dict[str, Any]:
hub = _hub()
model_arg = str(model) model_arg = str(model)
if model_name and Path(model_arg).exists(): if model_name and Path(model_arg).exists():
model_arg = hub.upload_model(model_arg, name=model_name) model_arg = hub.upload_model(model_arg, name=model_name)
job = hub.submit_quantize_job( job = hub.submit_quantize_job(
model=model_arg, model=model_arg,
calibration_data=_dataset_entries(calibration_data), calibration_data=_dataset_entries(calibration_data),
weights_dtype=QuantizeDtype.INT8, weights_dtype=hub.QuantizeDtype.INT8,
activations_dtype=QuantizeDtype.INT8, activations_dtype=hub.QuantizeDtype.INT8,
name=job_name, name=job_name,
options=options or "", options=options or "",
) )
target_model = job.get_target_model() target_model = _target_model(job)
if target_model is None: if target_model is None:
raise RuntimeError(f"Quantize job {job.job_id} did not produce a target model.") raise RuntimeError(f"Quantize job {_id(job)} did not produce a target model.")
return {"job": job, "job_id": str(job.job_id), "model": target_model, "model_id": str(target_model.model_id)} return {"job": job, "job_id": _id(job), "model": target_model, "model_id": _id(target_model)}
def download_model(model_id: str, output_path: str | Path) -> str: def download_model(model_id: str, output_path: str | Path) -> str:
dest = Path(output_path) dest = Path(output_path)
dest.parent.mkdir(parents=True, exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
model = hub.get_model(model_id) model = get_model(model_id)
if hasattr(model, "download"):
result = model.download(str(dest)) result = model.download(str(dest))
return str(result or dest) return str(result or dest)
if hasattr(model, "download_model"):
result = model.download_model(str(dest))
return str(result or dest)
raise RuntimeError("AI Hub model object does not expose a download method.")

View File

@@ -5,7 +5,7 @@ from typing import Any, Protocol
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config from src.aws import mlflow as aws_mlflow
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
@@ -30,7 +30,6 @@ class MlflowTracker:
experiment_name: str experiment_name: str
registered_model_name: str registered_model_name: str
register_trained_models: bool register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod @classmethod
def from_config(cls, cfg: Config) -> Tracker: def from_config(cls, cfg: Config) -> Tracker:
@@ -43,10 +42,11 @@ class MlflowTracker:
if not tracking_server_name: if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.") raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_backend = mlflow_tracking_backend_from_config(cfg) tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region,
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name) cfg.aws.profile,
with tracking_backend.auth_env(): tracking_server_name,
)
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,30 +55,34 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name, experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name, registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models, register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
run = mlflow.start_run(run_name=training_job.job_name) run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id) run_id = str(run.info.run_id)
self._log_params( params = {
self.tracking_backend.training_run_params( "aws.region": region,
training_job, "aws.profile": profile,
region=region, "sagemaker.role_arn": role_arn,
profile=profile, "sagemaker.job_name": training_job.job_name,
role_arn=role_arn, "sagemaker.training_image": training_job.image_uri,
) "sagemaker.instance_type": training_job.instance_type,
) "sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name, "qc_cli.source": "sagemaker",
"qc_cli.command": "train start", "qc_cli.command": "train start",
**self.tracking_backend.training_run_tags(training_job), "sagemaker.job_name": training_job.job_name,
} }
) )
mlflow.end_run() mlflow.end_run()
@@ -88,9 +92,16 @@ class MlflowTracker:
if not run_id: if not run_id:
return None return None
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params(self.tracking_backend.training_status_params(training_job_status)) self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status") mlflow.set_tag("qc_cli.command", "train status")
@@ -110,8 +121,8 @@ class MlflowTracker:
tags={ tags={
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": self.tracking_backend.provider_name, "qc_cli.source": "sagemaker",
**self.tracking_backend.model_version_tags(training_job_status), "sagemaker.job_name": training_job_status.name,
}, },
) )
version_number = str(version.version) version_number = str(version.version)

50
uv.lock generated
View File

@@ -1003,6 +1003,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" }, { url = "https://files.pythonhosted.org/packages/8a/db/55a262f3606bebcae07cc14095338471ad7c0bbcaa37707e6f0ee49725b7/importlib_resources-7.1.0-py3-none-any.whl", hash = "sha256:1bd7b48b4088eddb2cd16382150bb515af0bd2c70128194392725f82ad2c96a1", size = 37232, upload-time = "2026-04-12T16:36:08.219Z" },
] ]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]] [[package]]
name = "itsdangerous" name = "itsdangerous"
version = "2.2.0" version = "2.2.0"
@@ -1665,6 +1674,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" }, { url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" },
] ]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]] [[package]]
name = "prettytable" name = "prettytable"
version = "3.17.0" version = "3.17.0"
@@ -1945,6 +1963,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" }, { url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" },
] ]
[[package]]
name = "pytest"
version = "9.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]
[[package]]
name = "pytest-mock"
version = "3.15.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"
@@ -2068,6 +2114,8 @@ dependencies = [
dev = [ dev = [
{ name = "boto3-stubs", extra = ["iam", "s3", "sagemaker"] }, { name = "boto3-stubs", extra = ["iam", "s3", "sagemaker"] },
{ name = "pyright" }, { name = "pyright" },
{ name = "pytest" },
{ name = "pytest-mock" },
{ name = "ruff" }, { name = "ruff" },
{ name = "types-pyyaml" }, { name = "types-pyyaml" },
] ]
@@ -2090,6 +2138,8 @@ requires-dist = [
dev = [ dev = [
{ name = "boto3-stubs", extras = ["iam", "s3", "sagemaker"] }, { name = "boto3-stubs", extras = ["iam", "s3", "sagemaker"] },
{ name = "pyright", specifier = ">=1.1.409" }, { name = "pyright", specifier = ">=1.1.409" },
{ name = "pytest", specifier = ">=8.0" },
{ name = "pytest-mock", specifier = ">=3.12" },
{ name = "ruff", specifier = ">=0.4" }, { name = "ruff", specifier = ">=0.4" },
{ name = "types-pyyaml" }, { name = "types-pyyaml" },
] ]