command to create presigned URL for MLFlow

This commit is contained in:
2026-05-27 10:52:08 -04:00
parent e1c8d6574f
commit 58681cef82
6 changed files with 64 additions and 6 deletions

View File

@@ -28,3 +28,9 @@ def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
if not arn:
raise ValueError(f"MLflow tracking server has no ARN: {name}")
return str(arn)
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"])

View File

@@ -150,6 +150,32 @@ def status(config: str = CONFIG_OPT) -> None:
CONSOLE.print(table)
@app.command(name="mlflow-url")
def mlflow_url(config: str = CONFIG_OPT) -> None:
"""Print a presigned URL for the configured MLflow tracking server."""
cfg = load_cfg(config)
tracking_server_name = _mlflow_tracking_server_name(cfg)
try:
url = mlflow.create_presigned_tracking_server_url(
cfg.aws.region,
cfg.aws.profile,
tracking_server_name,
)
except Exception as e:
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"Reason: {e}")
CONSOLE.print(
"This command can create presigned URLs only for MLflow tracking servers managed by "
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
)
raise typer.Exit(1)
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
CONSOLE.print(f"MLflow UI: {url}")
@app.command()
def destroy(
config: str = CONFIG_OPT,
@@ -210,6 +236,15 @@ def _role_name(configured_name: str, role_arn: str) -> str:
return role_arn.rsplit("/", 1)[-1]
return "-"
def _mlflow_tracking_server_name(cfg: Config) -> str:
name = cfg.effective_mlflow_tracking_server_name
if not name:
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
raise typer.Exit(1)
return name
def _destroy_account_id(config_path: str, cfg: Config) -> str:
config_dir = str(Path(config_path).parent)
state = read_infra_state(config_dir)

View File

@@ -8,7 +8,7 @@ from src import state as state_ops
from src.aws import iam
from src.aws import sagemaker as sm_ops
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
from src.config import Config
from src.config import Config, MlflowMode
from src.infra.state import read_infra_state
from src.tracking.mlflow import MlflowTracker
@@ -101,6 +101,7 @@ def start(config: str = CONFIG_OPT) -> None:
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
if run_id:
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
@@ -137,7 +138,8 @@ def status(
run_id = job_state.get("mlflow_run_id")
already_registered = job_state.get("registered_model_version")
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
version = _tracker(cfg).finalize_training_run(
tracker = _tracker(cfg)
version = tracker.finalize_training_run(
run_id=str(run_id),
training_job_status=status,
)
@@ -148,6 +150,8 @@ def status(
if version:
st.set_latest_prerelease_model_version(version)
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]prerelease-latest[/cyan])")
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
@app.command(name="list")

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any, Protocol
@@ -40,6 +41,8 @@ class MlflowTracker:
if cfg.mlflow.mode is MlflowMode.disabled:
return NoopTracker()
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
try:
import mlflow
except ImportError as e: