1 Commits

Author SHA1 Message Date
3846c5d88d add aws context for MLFlow 2026-06-05 15:52:55 -04:00
12 changed files with 394 additions and 406 deletions

View File

@@ -13,10 +13,18 @@ This example takes the ONNX model produced by the SageMaker training example and
Run the training example first and wait for it to complete:
```bash
examples/training/run_training.sh --wait
bash examples/training/run_training.sh --config config.yaml --wait
```
The `config.yaml` file must include AI Hub settings:
If the dataset is already uploaded to S3, use:
```bash
bash examples/training/run_training.sh --config config.yaml --skip-upload --wait
```
The training artifact must contain a static-shape `model.onnx`. The training example exports an input named `input` with shape `1x3x160x160`.
Your `config.yaml` must include AI Hub settings:
```yaml
aihub:
@@ -28,20 +36,16 @@ aihub:
output_dir: build/qai-hub
```
Finally, the user needs to authenticate with Qualcomm AI Hub using:
```bash
qai-hub configure --api_token
```
You also need local Qualcomm AI Hub SDK authentication configured.
## Prepare Inputs
AI Hub does not consume the raw JPG training images directly. It needs NumPy tensors that match the ONNX model input shape and preprocessing.
To generate calibration and validation inputs:
Generate calibration and validation inputs:
```bash
python examples/ai-hub/prepare_inputs.py
uv run python examples/ai-hub/prepare_inputs.py
```
This writes:
@@ -57,23 +61,58 @@ The script applies the same image preprocessing used by the training example:
- convert to channel-first `1x3x160x160`
- normalize with ImageNet mean and standard deviation
## Upload Model to Qualcomm Workbench
The model can be uploaded to Qualcomm Workbench using:
Useful options:
```bash
qc-cli ai-hub upload examples/training/data/aihub_calibration examples/training/data/inputs.npz
uv run python examples/ai-hub/prepare_inputs.py \
--dataset-dir examples/training/data/flower_photos_sagemaker \
--calibration-dir examples/training/data/aihub_calibration \
--input-file examples/training/data/inputs.npz \
--samples 16
```
The first argument is the calibration path for the model and the second argument is the input file, both of which were created by the `prepare_inputs.py` script. For more details, add `--help` after the `upload` command.
## Run AI Hub
The `upload` command runs the following commands in order:
1. `qc-cli ai-hub quantize`
2. `qc-cli ai-hub compile`
3. `qc-cli ai-hub validate`
4. `qc-cli ai-hub profile`
After training completes and inputs are prepared:
Finally the user can download the model from AI Workbench using the command
```bash
qc-cli ai-hub download
bash examples/ai-hub/run_ai_hub.sh --config config.yaml
```
By default, the script uses the last SageMaker training job recorded in `.qc-cli.json`. It downloads that job's `model.tar.gz`, extracts `model.onnx`, runs the AI Hub workflow, and downloads the compiled artifact.
To use a specific training job:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-job qc-cli-YYYYMMDD-HHMMSS
```
To resume from a later Workbench step:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--from-step validate
```
To skip downloading the compiled artifact:
```bash
bash examples/ai-hub/run_ai_hub.sh \
--config config.yaml \
--skip-download
```
## Troubleshooting
If AI Hub reports dynamic input shapes, rerun training with the current training source. AI Hub quantization requires the exported ONNX model to use static input shapes.
If `run_ai_hub.sh` reports missing calibration or input files, run:
```bash
uv run python examples/ai-hub/prepare_inputs.py
```
If validation fails with a missing input name, make sure `config.yaml` and the generated `.npz` both use `input` as the input name.

0
examples/ai-hub/prepare_inputs.py Normal file → Executable file
View File

156
examples/ai-hub/run_ai_hub.sh Executable file
View File

@@ -0,0 +1,156 @@
#!/usr/bin/env bash
set -euo pipefail
CONFIG_PATH="config.yaml"
CALIBRATION_PATH="examples/training/data/aihub_calibration"
INPUT_FILE="examples/training/data/inputs.npz"
FROM_STEP="quantize"
FROM_JOB=""
MODEL_S3_URI=""
ONNX_PATH=""
INPUT_NAME=""
DOWNLOAD=true
OUTPUT_PATH=""
usage() {
cat <<EOF
Usage: $0 [options]
Options:
--config PATH Path to qc-cli config file. Default: config.yaml
--calibration PATH Calibration .npz file or directory of .npy samples.
Default: ${CALIBRATION_PATH}
--input-file PATH Validation .npz or .npy inputs. Default: ${INPUT_FILE}
--from-step STEP Resume upload from: quantize, compile, validate, profile.
Default: ${FROM_STEP}
--from-job NAME SageMaker training job whose model artifact should upload.
Defaults to the last training job in local qc-cli state.
--model-s3-uri URI S3 URI of model.tar.gz to upload.
--onnx-path PATH Local ONNX path or ONNX path inside extracted artifact.
--input-name NAME Input name for .npy validation files.
--skip-download Do not download the compiled AI Hub artifact after upload.
--output PATH Destination file for ai-hub download.
-h, --help Show this help.
EOF
}
while [[ $# -gt 0 ]]; do
case "$1" in
--config)
CONFIG_PATH="$2"
shift 2
;;
--calibration)
CALIBRATION_PATH="$2"
shift 2
;;
--input-file)
INPUT_FILE="$2"
shift 2
;;
--from-step)
FROM_STEP="$2"
shift 2
;;
--from-job)
FROM_JOB="$2"
shift 2
;;
--model-s3-uri)
MODEL_S3_URI="$2"
shift 2
;;
--onnx-path)
ONNX_PATH="$2"
shift 2
;;
--input-name)
INPUT_NAME="$2"
shift 2
;;
--skip-download)
DOWNLOAD=false
shift
;;
--output)
OUTPUT_PATH="$2"
shift 2
;;
-h|--help)
usage
exit 0
;;
*)
echo "Unknown option: $1" >&2
usage >&2
exit 1
;;
esac
done
if [[ ! -f "${CONFIG_PATH}" ]]; then
echo "Config not found: ${CONFIG_PATH}" >&2
exit 1
fi
case "${FROM_STEP}" in
quantize|compile|validate|profile)
;;
*)
echo "--from-step must be one of: quantize, compile, validate, profile" >&2
exit 1
;;
esac
if [[ ! -e "${CALIBRATION_PATH}" ]]; then
echo "Calibration path not found: ${CALIBRATION_PATH}" >&2
echo "Pass --calibration with a .npz file or directory of .npy samples." >&2
exit 1
fi
if [[ ! -f "${INPUT_FILE}" ]]; then
echo "Input file not found: ${INPUT_FILE}" >&2
echo "Pass --input-file with a validation .npz or .npy file." >&2
exit 1
fi
run() {
echo "+ $*"
"$@"
}
UPLOAD_ARGS=(
"${CALIBRATION_PATH}"
"${INPUT_FILE}"
--from-step "${FROM_STEP}"
--config "${CONFIG_PATH}"
)
if [[ -n "${FROM_JOB}" ]]; then
UPLOAD_ARGS+=(--from-job "${FROM_JOB}")
fi
if [[ -n "${MODEL_S3_URI}" ]]; then
UPLOAD_ARGS+=(--model-s3-uri "${MODEL_S3_URI}")
fi
if [[ -n "${ONNX_PATH}" ]]; then
UPLOAD_ARGS+=(--onnx-path "${ONNX_PATH}")
fi
if [[ -n "${INPUT_NAME}" ]]; then
UPLOAD_ARGS+=(--input-name "${INPUT_NAME}")
fi
run uv run qc-cli ai-hub upload "${UPLOAD_ARGS[@]}"
if [[ "${DOWNLOAD}" == false ]]; then
exit 0
fi
DOWNLOAD_ARGS=(--config "${CONFIG_PATH}")
if [[ -n "${OUTPUT_PATH}" ]]; then
DOWNLOAD_ARGS+=(--output "${OUTPUT_PATH}")
fi
run uv run qc-cli ai-hub download "${DOWNLOAD_ARGS[@]}"

View File

@@ -1,186 +0,0 @@
# YOLO26 Electric Meter Detection Example
This example trains a YOLO26 object detection model on the Roboflow Universe electric meter dataset using the existing `qc-cli` SageMaker training flow.
The workflow is intentionally command driven. Run each step yourself so you can inspect the dataset, update `config.yaml`, and decide when to submit the SageMaker job.
Dataset:
```text
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
```
## Prerequisites
- Install or sync the project dependencies: `uv sync`
- The virtual environment is activated.
- AWS credentials configured for the profile in `config.yaml`
- Infrastructure already deployed with `qc-cli infra setup`
## 1. Download The Dataset
Register or sign in to Roboflow, then open the dataset page:
```text
https://universe.roboflow.com/kemals-workspace-kbc8l/electric-meter-detection-o4tfi/dataset/1
```
Download the dataset in YOLOv26 format from the Roboflow UI, then extract the downloaded archive into:
```text
examples/meter-detection/data/electric-meter-detection
```
The `data.yaml` file should be directly under that folder:
```text
examples/meter-detection/data/electric-meter-detection/data.yaml
```
Do not move `data.yaml` into the `train/` split folder.
After extracting, confirm the dataset has a YOLO data file and image splits:
```bash
find examples/meter-detection/data/electric-meter-detection -maxdepth 2 -type d | sort
find examples/meter-detection/data/electric-meter-detection -name data.yaml -print
```
Open `examples/meter-detection/data/electric-meter-detection/data.yaml` and make sure the split paths are relative to that folder:
```yaml
path: .
train: train/images
val: valid/images
test: test/images
```
If your downloaded dataset does not include a `test/` folder, remove the `test:` line.
The expected layout is similar to:
```text
examples/meter-detection/data/electric-meter-detection/
data.yaml
train/
valid/
test/
```
## 2. Configure SageMaker Training
Update `config.yaml` so the training section points at this example's source directory:
```yaml
sagemaker:
training:
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
instance_type: ml.g4dn.xlarge
instance_count: 1
source_dir: examples/meter-detection/source
entry_point: train.py
hyperparameters:
model: yolo26n.pt
epochs: 25
imgsz: 640
batch: 16
workers: 2
```
Use `yolo26n.pt` for a lightweight first YOLO26 run. If those weights are unavailable in the installed Ultralytics package, use `yolo11n.pt` as the established fallback:
```yaml
model: yolo11n.pt
```
The `source/requirements.txt` file is installed by the SageMaker PyTorch container before running `train.py`.
For a CPU smoke test, use a CPU instance and reduce the workload:
```yaml
sagemaker:
training:
image_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6-cpu-py312-ubuntu22.04-sagemaker-v1
instance_type: ml.m4.xlarge
instance_count: 1
source_dir: examples/meter-detection/source
entry_point: train.py
hyperparameters:
model: yolo26n.pt
epochs: 1
imgsz: 320
batch: 4
workers: 2
```
## 3. Check Infrastructure
Confirm the CLI can see the configured SageMaker role and S3 bucket:
```bash
qc-cli infra status --config config.yaml
```
## 4. Upload The Dataset
Upload the downloaded Roboflow dataset to the `s3.data_prefix` configured in `config.yaml`:
```bash
qc-cli upload examples/meter-detection/data/electric-meter-detection
```
Directory uploads preserve paths relative to the uploaded directory, so SageMaker receives the dataset root with `data.yaml` plus the split directories.
In SageMaker, this uploaded dataset root is mounted at `/opt/ml/input/data/train`. That `train` path is the SageMaker channel name, not the YOLO `train/` split folder.
## 5. Start Training
Submit the SageMaker training job:
```bash
qc-cli train start
```
The command prints the submitted SageMaker job name. Check progress with:
```bash
qc-cli train status
```
Or pass the job name explicitly:
```bash
qc-cli train status qc-cli-YYYYMMDD-HHMMSS
```
## Outputs
When the job completes, SageMaker packages the files written under `/opt/ml/model` into `model.tar.gz`.
This example writes:
```text
best.pt
model.onnx
metrics.json
```
The archive is stored under the configured `s3.model_prefix`.
## Training Hyperparameters
Values under `sagemaker.training.hyperparameters` are passed to `source/train.py` as command-line arguments.
| Name | Type | Default | Description |
|---|---:|---:|---|
| `model` | string | `yolo26n.pt` | Ultralytics model weights or model YAML. |
| `epochs` | int | `25` | Number of training epochs. |
| `imgsz` | int | `640` | Square training image size. |
| `batch` | int | `16` | Images per training batch. |
| `workers` | int | `2` | DataLoader worker count. |
| `patience` | int | `20` | Early stopping patience. |
| `device` | string | auto | Optional Ultralytics device value such as `0` or `cpu`. |
| `data-yaml` | string | auto | Optional path to `data.yaml`; normally discovered from the uploaded dataset root. |
| `dataset-dir` | string | `SM_CHANNEL_TRAIN` | Uploaded dataset root mounted by SageMaker. |
Do not set `dataset-dir` or `model-dir` in normal SageMaker runs. SageMaker sets those automatically through `SM_CHANNEL_TRAIN` and `SM_MODEL_DIR`.

View File

@@ -1,3 +0,0 @@
ultralytics>=8.3.0
pyyaml>=6.0.3
onnx>=1.16.0

View File

@@ -1,124 +0,0 @@
#!/usr/bin/env python3
"""SageMaker entry point for YOLO electric meter detection training."""
from __future__ import annotations
import argparse
import json
import os
import shutil
from pathlib import Path
from typing import Any
import yaml
from ultralytics import YOLO # type: ignore[reportMissingImports]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="yolo26n.pt")
parser.add_argument("--epochs", type=int, default=25)
parser.add_argument("--imgsz", type=int, default=640)
parser.add_argument("--batch", type=int, default=16)
parser.add_argument("--workers", type=int, default=2)
parser.add_argument("--patience", type=int, default=20)
parser.add_argument("--device", default=None)
parser.add_argument("--data-yaml", default=None)
parser.add_argument("--dataset-dir", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
parser.add_argument("--train-dir", dest="dataset_dir", help=argparse.SUPPRESS)
parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
return parser.parse_args()
def find_data_yaml(dataset_dir: Path, explicit_path: str | None) -> Path:
if explicit_path:
data_yaml = Path(explicit_path)
if data_yaml.is_file():
return data_yaml
raise FileNotFoundError(f"Configured data.yaml does not exist: {data_yaml}")
matches = sorted(dataset_dir.rglob("data.yaml"))
if not matches:
raise FileNotFoundError(f"Could not find data.yaml under {dataset_dir}")
if len(matches) > 1:
print(f"Found multiple data.yaml files; using {matches[0]}")
return matches[0]
def prepare_data_yaml(data_yaml: Path) -> Path:
"""Write a SageMaker-local data file rooted at the uploaded dataset."""
dataset_root = data_yaml.parent
data = yaml.safe_load(data_yaml.read_text(encoding="utf-8"))
if not isinstance(data, dict):
raise ValueError(f"Expected a mapping in {data_yaml}")
normalized = dict(data)
normalized["path"] = str(dataset_root)
if "val" not in normalized and "valid" in normalized:
normalized["val"] = normalized.pop("valid")
prepared_path = dataset_root / "data.sagemaker.yaml"
prepared_path.write_text(yaml.safe_dump(normalized, sort_keys=False), encoding="utf-8")
print(f"Prepared dataset config: {prepared_path}")
return prepared_path
def copy_if_exists(source: Path, destination: Path) -> None:
if source.exists():
shutil.copy2(source, destination)
print(f"Saved {destination}")
def main() -> None:
args = parse_args()
dataset_dir = Path(args.dataset_dir)
model_dir = Path(args.model_dir)
model_dir.mkdir(parents=True, exist_ok=True)
data_yaml = prepare_data_yaml(find_data_yaml(dataset_dir, args.data_yaml))
model = YOLO(args.model)
train_kwargs: dict[str, Any] = {
"data": str(data_yaml),
"epochs": args.epochs,
"imgsz": args.imgsz,
"batch": args.batch,
"workers": args.workers,
"patience": args.patience,
"project": str(model_dir / "runs"),
"name": "train",
"exist_ok": True,
}
if args.device:
train_kwargs["device"] = args.device
results = model.train(**train_kwargs)
save_dir = Path(results.save_dir)
best_pt = save_dir / "weights" / "best.pt"
last_pt = save_dir / "weights" / "last.pt"
trained_weights = best_pt if best_pt.exists() else last_pt
if not trained_weights.exists():
raise FileNotFoundError(f"Could not find trained weights in {save_dir / 'weights'}")
copy_if_exists(trained_weights, model_dir / "best.pt")
trained_model = YOLO(str(trained_weights))
onnx_path = Path(trained_model.export(format="onnx", imgsz=args.imgsz))
copy_if_exists(onnx_path, model_dir / "model.onnx")
metrics = {
"model": args.model,
"epochs": args.epochs,
"imgsz": args.imgsz,
"batch": args.batch,
"workers": args.workers,
"patience": args.patience,
"data_yaml": str(data_yaml),
"weights": str(trained_weights),
"onnx": str(onnx_path),
}
(model_dir / "metrics.json").write_text(json.dumps(metrics, indent=2), encoding="utf-8")
print(f"Saved model artifacts to {model_dir}")
if __name__ == "__main__":
main()

View File

@@ -1,3 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast
import boto3
@@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
if credentials is None:
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
frozen_credentials = credentials.get_frozen_credentials()
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
env_updates = {
"AWS_PROFILE": profile,
"AWS_DEFAULT_REGION": region,
"AWS_REGION": region,
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
}
if frozen_credentials.token:
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
previous_env = {key: os.environ.get(key) for key in restore_keys}
try:
os.environ.update(env_updates)
if not frozen_credentials.token:
os.environ.pop("AWS_SESSION_TOKEN", None)
yield
finally:
for key, value in previous_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

View File

@@ -0,0 +1 @@
"""Cloud provider adapters."""

77
src/cloud/mlflow.py Normal file
View File

@@ -0,0 +1,77 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
@property
def provider_name(self) -> str: ...
@property
def profile(self) -> str: ...
@property
def region(self) -> str: ...
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
def auth_env(self) -> AbstractContextManager[None]: ...
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
profile: str
region: str
provider_name: str = "aws"
def get_tracking_uri(self, tracking_server_name: str) -> str:
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
def auth_env(self) -> AbstractContextManager[None]:
return aws_mlflow.tracking_auth_env(self.profile, self.region)
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
return {
"provider.name": self.provider_name,
"provider.region": region,
"provider.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job.job_name}
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
return {
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)

View File

@@ -14,7 +14,7 @@ from src.config import Config
from src.qualcomm import aihub_jobs
from src.qualcomm.artifacts import resolve_onnx
app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm Workbench")
app = typer.Typer(help="Quantize, compile, validate, profile, and download models with Qualcomm AI Hub")
_RUNTIME_EXTENSIONS = {
"tflite": "tflite",

View File

@@ -0,0 +1 @@

View File

@@ -5,7 +5,7 @@ from typing import Any, Protocol
import mlflow
from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode
@@ -30,6 +30,7 @@ class MlflowTracker:
experiment_name: str
registered_model_name: str
register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod
def from_config(cls, cfg: Config) -> Tracker:
@@ -42,11 +43,10 @@ class MlflowTracker:
if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_uri = aws_mlflow.get_tracking_server_arn(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
tracking_backend = mlflow_tracking_backend_from_config(cfg)
tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
with tracking_backend.auth_env():
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,34 +55,30 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
)
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id)
params = {
"aws.region": region,
"aws.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params(
self.tracking_backend.training_run_params(
training_job,
region=region,
profile=profile,
role_arn=role_arn,
)
)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags(
{
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name,
**self.tracking_backend.training_run_tags(training_job),
}
)
mlflow.end_run()
@@ -92,16 +88,9 @@ class MlflowTracker:
if not run_id:
return None
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id):
self._log_params(
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_params(self.tracking_backend.training_status_params(training_job_status))
self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status")
@@ -121,8 +110,8 @@ class MlflowTracker:
tags={
"qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker",
"sagemaker.job_name": training_job_status.name,
"qc_cli.source": self.tracking_backend.provider_name,
**self.tracking_backend.model_version_tags(training_job_status),
},
)
version_number = str(version.version)