command to create presigned URL for MLFlow
This commit is contained in:
@@ -28,3 +28,9 @@ def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
|
||||
if not arn:
|
||||
raise ValueError(f"MLflow tracking server has no ARN: {name}")
|
||||
return str(arn)
|
||||
|
||||
|
||||
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
||||
return str(response["AuthorizedUrl"])
|
||||
|
||||
@@ -150,6 +150,32 @@ def status(config: str = CONFIG_OPT) -> None:
|
||||
CONSOLE.print(table)
|
||||
|
||||
|
||||
@app.command(name="mlflow-url")
|
||||
def mlflow_url(config: str = CONFIG_OPT) -> None:
|
||||
"""Print a presigned URL for the configured MLflow tracking server."""
|
||||
cfg = load_cfg(config)
|
||||
tracking_server_name = _mlflow_tracking_server_name(cfg)
|
||||
|
||||
try:
|
||||
url = mlflow.create_presigned_tracking_server_url(
|
||||
cfg.aws.region,
|
||||
cfg.aws.profile,
|
||||
tracking_server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
CONSOLE.print("[yellow]Could not create a SageMaker MLflow UI URL.[/yellow]")
|
||||
CONSOLE.print(f"Tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||
CONSOLE.print(f"Reason: {e}")
|
||||
CONSOLE.print(
|
||||
"This command can create presigned URLs only for MLflow tracking servers managed by "
|
||||
"Amazon SageMaker. If this is an external MLflow server, open it with that server's own URL."
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
CONSOLE.print(f"MLflow tracking server: [cyan]{tracking_server_name}[/cyan]")
|
||||
CONSOLE.print(f"MLflow UI: {url}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def destroy(
|
||||
config: str = CONFIG_OPT,
|
||||
@@ -210,6 +236,15 @@ def _role_name(configured_name: str, role_arn: str) -> str:
|
||||
return role_arn.rsplit("/", 1)[-1]
|
||||
return "-"
|
||||
|
||||
|
||||
def _mlflow_tracking_server_name(cfg: Config) -> str:
|
||||
name = cfg.effective_mlflow_tracking_server_name
|
||||
if not name:
|
||||
CONSOLE.print("[red]MLflow is disabled in config.yaml.[/red]")
|
||||
raise typer.Exit(1)
|
||||
return name
|
||||
|
||||
|
||||
def _destroy_account_id(config_path: str, cfg: Config) -> str:
|
||||
config_dir = str(Path(config_path).parent)
|
||||
state = read_infra_state(config_dir)
|
||||
|
||||
@@ -8,7 +8,7 @@ from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra.state import read_infra_state
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
@@ -101,6 +101,7 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@@ -137,7 +138,8 @@ def status(
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
already_registered = job_state.get("registered_model_version")
|
||||
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||
version = _tracker(cfg).finalize_training_run(
|
||||
tracker = _tracker(cfg)
|
||||
version = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
)
|
||||
@@ -148,6 +150,8 @@ def status(
|
||||
if version:
|
||||
st.set_latest_prerelease_model_version(version)
|
||||
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]prerelease-latest[/cyan])")
|
||||
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
@@ -40,6 +41,8 @@ class MlflowTracker:
|
||||
if cfg.mlflow.mode is MlflowMode.disabled:
|
||||
return NoopTracker()
|
||||
|
||||
os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")
|
||||
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError as e:
|
||||
|
||||
Reference in New Issue
Block a user