MLflow implementation (#2)
Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
@@ -8,8 +8,9 @@ from src import state as state_ops
|
||||
from src.aws import iam
|
||||
from src.aws import sagemaker as sm_ops
|
||||
from src.commands.utils import CONFIG_OPT, CONSOLE, load_cfg
|
||||
from src.config import Config
|
||||
from src.config import Config, MlflowMode
|
||||
from src.infra.state import read_infra_state
|
||||
from src.tracking.mlflow import MlflowTracker
|
||||
|
||||
app = typer.Typer(help="Manage SageMaker training jobs")
|
||||
|
||||
@@ -22,6 +23,14 @@ _STATUS_COLOR = {
|
||||
}
|
||||
|
||||
|
||||
def _tracker(cfg):
    """Build the MLflow tracker for *cfg*, exiting the CLI if setup fails.

    Returns whatever ``MlflowTracker.from_config(cfg)`` returns.

    Raises:
        typer.Exit: with code 1 when ``from_config`` raises any exception;
            the failure message is printed to the console first.
    """
    try:
        return MlflowTracker.from_config(cfg)
    except Exception as e:
        # Report the failure and stop with exit code 1; the original error
        # is kept as the cause of the Exit for debugging.
        CONSOLE.print(f"[red]MLflow setup failed: {e}[/red]")
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
def _config_dir(config_path: str) -> str:
|
||||
return str(Path(config_path).parent)
|
||||
|
||||
@@ -58,6 +67,7 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
CONSOLE.print(f"[red]{e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
tracker = _tracker(cfg)
|
||||
job_name = f"qc-cli-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
||||
s3_train_uri = f"s3://{cfg.s3.bucket}/{cfg.s3.data_prefix}"
|
||||
s3_output = f"s3://{cfg.s3.bucket}/{cfg.s3.model_prefix}"
|
||||
@@ -77,9 +87,21 @@ def start(config: str = CONFIG_OPT) -> None:
|
||||
)
|
||||
sm_ops.start_training_job(cfg.aws.boto3_session, training_job)
|
||||
|
||||
state_ops.write_state(_config_dir(config), last_training_job=job_name)
|
||||
st = state_ops.store(config)
|
||||
st.set_last_training_job(job_name)
|
||||
run_id = tracker.start_training_run(
|
||||
training_job,
|
||||
region=cfg.aws.region,
|
||||
profile=cfg.aws.profile,
|
||||
role_arn=role_arn,
|
||||
)
|
||||
if run_id:
|
||||
st.update_training_job(job_name, mlflow_run_id=run_id)
|
||||
|
||||
CONSOLE.print(f"[green]✓[/green] Job submitted: [bold]{job_name}[/bold]")
|
||||
if run_id:
|
||||
CONSOLE.print(f"MLflow run: [cyan]{run_id}[/cyan]")
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||
CONSOLE.print("Track progress: [cyan]qc-cli train status[/cyan]")
|
||||
|
||||
|
||||
@@ -90,9 +112,10 @@ def status(
|
||||
) -> None:
|
||||
"""Show training job status."""
|
||||
cfg = load_cfg(config)
|
||||
st = state_ops.store(config)
|
||||
|
||||
if not job_name:
|
||||
job_name = state_ops.get_last_training_job(_config_dir(config))
|
||||
job_name = st.get_last_training_job()
|
||||
if not job_name:
|
||||
CONSOLE.print(
|
||||
"[red]No training job found in state. Pass a job name or run 'qc-cli train start' first.[/red]"
|
||||
@@ -111,6 +134,25 @@ def status(
|
||||
if status.failure_reason:
|
||||
CONSOLE.print(f"[red]Failure: {status.failure_reason}[/red]")
|
||||
|
||||
job_state = st.get_training_job(job_name)
|
||||
run_id = job_state.get("mlflow_run_id")
|
||||
already_registered = job_state.get("registered_model_version")
|
||||
if run_id and not already_registered and status.status in {"Completed", "Failed", "Stopped"}:
|
||||
tracker = _tracker(cfg)
|
||||
version = tracker.finalize_training_run(
|
||||
run_id=str(run_id),
|
||||
training_job_status=status,
|
||||
)
|
||||
updates = {"mlflow_finalized_status": status.status}
|
||||
if version:
|
||||
updates["registered_model_version"] = version
|
||||
st.update_training_job(job_name, **updates)
|
||||
if version:
|
||||
st.set_latest_experiment_model_version(version)
|
||||
CONSOLE.print(f"MLflow model version: [cyan]{version}[/cyan] ([cyan]experiment-latest[/cyan])")
|
||||
if run_id and cfg.mlflow.mode is not MlflowMode.disabled:
|
||||
CONSOLE.print("Open MLflow: [cyan]qc-cli infra mlflow-url[/cyan]")
|
||||
|
||||
|
||||
@app.command(name="list")
|
||||
def list_jobs(
|
||||
|
||||
Reference in New Issue
Block a user