Files
qai-cli/src/tracking/mlflow.py
2026-06-12 18:23:25 +00:00

268 lines
9.9 KiB
Python

import os
import tempfile
from dataclasses import dataclass
from typing import Any, Protocol
import mlflow
from mlflow.tracking import MlflowClient
from src.aws import s3
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar
@dataclass(frozen=True)
class FinalizeResult:
registered_model_version: str | None = None
warnings: tuple[str, ...] = ()
class Tracker(Protocol):
    """Structural interface for experiment-tracking backends.

    Any object providing these four methods with compatible signatures
    satisfies the protocol; no inheritance is required.
    """

    def start_training_run(
        self,
        training_job: Any,
        *,
        region: str,
        profile: str,
        role_arn: str,
    ) -> str | None:
        """Begin tracking ``training_job``; return a run id, or ``None`` if none was created."""
        ...

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Record the final state of a training job against ``run_id`` and report the outcome."""
        ...

    def ensure_training_run(self, job_name: str) -> str:
        """Return the id of the run associated with ``job_name``."""
        ...

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Upload training metrics for ``run_id``; the boolean reports the upload outcome."""
        ...
@dataclass(frozen=True)
class NoopTracker:
    """Tracker used when MLflow is disabled.

    Starting and finalizing a run are silent no-ops. The two operations that
    only make sense with MLflow available (``ensure_training_run`` and
    ``upload_training_metrics``) raise ``RuntimeError`` instead.
    """

    def start_training_run(
        self,
        training_job: Any,
        *,
        region: str,
        profile: str,
        role_arn: str,
    ) -> str | None:
        """Do nothing and report that no run was created."""
        return None

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Do nothing and return an empty ``FinalizeResult``."""
        return FinalizeResult()

    def ensure_training_run(self, job_name: str) -> str:
        """Always raise ``RuntimeError``: there is no run to look up or create."""
        raise RuntimeError("MLflow is disabled.")

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Always raise ``RuntimeError``: metrics cannot be uploaded anywhere."""
        raise RuntimeError("MLflow is disabled.")
@dataclass(frozen=True)
class MlflowTracker:
    """Tracker implementation that records training runs in MLflow.

    Every MLflow call is made inside ``tracking_backend.auth_env()`` so that
    whatever credentials/environment the backend provides are in effect.
    Instances are normally built with :meth:`from_config`, which also points
    the global MLflow client at the resolved tracking URI and experiment.
    """

    # URI returned by the tracking backend for the configured server.
    tracking_uri: str
    # Name of the MLflow experiment runs are recorded under.
    experiment_name: str
    # Name of the registered model that new versions are attached to.
    registered_model_name: str
    # When False, finalize_training_run never creates model versions.
    register_trained_models: bool
    # Provider-specific helper supplying auth, params and tags.
    tracking_backend: MlflowTrackingBackend

    @classmethod
    def from_config(cls, cfg: Config) -> Tracker:
        """Build the tracker selected by ``cfg``.

        Returns a ``NoopTracker`` when MLflow mode is disabled. Otherwise
        resolves the tracking URI through the backend, sets it (and the
        experiment) on the global ``mlflow`` module, and returns an
        ``MlflowTracker``.

        Raises:
            RuntimeError: if no tracking server name can be resolved.
        """
        if cfg.mlflow.mode is MlflowMode.disabled:
            return NoopTracker()

        # setdefault: an explicit value already in the environment wins.
        os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")

        tracking_server_name = cfg.effective_mlflow_tracking_server_name
        if not tracking_server_name:
            raise RuntimeError("MLflow tracking server name could not be resolved.")

        tracking_backend = mlflow_tracking_backend_from_config(cfg)
        tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
        with tracking_backend.auth_env():
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.set_experiment(cfg.mlflow.experiment_name)

        return cls(
            tracking_uri=tracking_uri,
            experiment_name=cfg.mlflow.experiment_name,
            registered_model_name=cfg.mlflow.registered_model_name,
            register_trained_models=cfg.mlflow.register_trained_models,
            tracking_backend=tracking_backend,
        )

    def start_training_run(
        self,
        training_job: Any,
        *,
        region: str,
        profile: str,
        role_arn: str,
    ) -> str | None:
        """Create an MLflow run named after the job and log its inputs.

        Logs the backend-provided run parameters, each hyperparameter under a
        ``hyperparameters.`` prefix, and the ``qc_cli.*`` tags merged with the
        backend's own run tags (backend tags win on key collisions).

        Returns:
            The id of the newly created run.
        """
        with self.tracking_backend.auth_env():
            with mlflow.start_run(run_name=training_job.job_name) as run:
                run_id = str(run.info.run_id)
                self._log_params(
                    self.tracking_backend.training_run_params(
                        training_job,
                        region=region,
                        profile=profile,
                        role_arn=role_arn,
                    )
                )
                self._log_params(
                    {f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}
                )
                mlflow.set_tags(
                    {
                        "qc_cli.stage": "experiment",
                        "qc_cli.artifact_kind": "trained_source",
                        "qc_cli.source": self.tracking_backend.provider_name,
                        "qc_cli.command": "train start",
                        **self.tracking_backend.training_run_tags(training_job),
                    }
                )
        return run_id

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Record the terminal state of a job and optionally register its model.

        ``region`` and ``profile`` are accepted for interface compatibility
        with ``Tracker`` and are not used here.

        Returns:
            An empty ``FinalizeResult`` when there is no ``run_id``, when the
            job did not complete with model artifacts, or when registration
            is turned off; otherwise a result carrying the new model version.
        """
        if not run_id:
            return FinalizeResult()

        with self.tracking_backend.auth_env():
            # Re-open the existing run so fluent mlflow.* calls target it.
            with mlflow.start_run(run_id=run_id):
                self._log_params(self.tracking_backend.training_status_params(training_job_status))
                self._log_final_metrics(training_job_status.raw)
                mlflow.set_tag("qc_cli.command", command)

                if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
                    mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
                    return FinalizeResult()
                if not self.register_trained_models:
                    return FinalizeResult()

                client = MlflowClient()
                self._ensure_registered_model(client, self.registered_model_name)
                version = client.create_model_version(
                    name=self.registered_model_name,
                    source=training_job_status.model_artifacts,
                    run_id=run_id,
                    tags={
                        "qc_cli.stage": "experiment",
                        "qc_cli.artifact_kind": "trained_source",
                        "qc_cli.source": self.tracking_backend.provider_name,
                        **self.tracking_backend.model_version_tags(training_job_status),
                    },
                )
                version_number = str(version.version)
                client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
                mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
                mlflow.set_tag("qc_cli.registered_model_version", version_number)
                return FinalizeResult(registered_model_version=version_number)

    def ensure_training_run(self, job_name: str) -> str:
        """Return the id of the run tagged with ``job_name``, creating one if absent.

        The experiment is created when it does not exist yet. Existing runs
        are matched on the ``sagemaker.job_name`` tag; every result page is
        scanned so a match is found even in experiments with more runs than
        a single ``search_runs`` page holds.
        """
        with self.tracking_backend.auth_env():
            client = MlflowClient()
            experiment = client.get_experiment_by_name(self.experiment_name)
            if experiment is None:
                experiment_id = mlflow.create_experiment(self.experiment_name)
            else:
                experiment_id = experiment.experiment_id

            existing_run_id = self._find_run_id_by_job_name(client, experiment_id, job_name)
            if existing_run_id is not None:
                return existing_run_id

            run = client.create_run(
                experiment_id,
                run_name=job_name,
                tags={
                    "qc_cli.stage": "experiment",
                    "qc_cli.artifact_kind": "trained_source",
                    "qc_cli.source": self.tracking_backend.provider_name,
                    "qc_cli.command": "mlflow upload-metrics",
                    "sagemaker.job_name": job_name,
                },
            )
            return str(run.info.run_id)

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Log status params, final metrics and metric history to ``run_id``.

        Returns:
            True when a metrics file was found in the model archive and its
            history was logged, False when the archive contained none. The
            same value is recorded in the ``qc_cli.metrics_history_uploaded``
            tag as ``"true"``/``"false"``.

        Raises:
            ValueError: if the training job has no model artifacts.
        """
        if not training_job_status.model_artifacts:
            raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")

        with self.tracking_backend.auth_env():
            with mlflow.start_run(run_id=run_id):
                self._log_params(self.tracking_backend.training_status_params(training_job_status))
                self._log_final_metrics(training_job_status.raw)
                history_uploaded = self._log_training_metrics(
                    training_job_status.model_artifacts,
                    region=region,
                    profile=profile,
                )
                mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
                mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
        return history_uploaded

    def _find_run_id_by_job_name(self, client: MlflowClient, experiment_id: str, job_name: str) -> str | None:
        """Scan all runs of the experiment for one tagged with ``job_name``."""
        page_token = None
        while True:
            page = client.search_runs([experiment_id], max_results=1000, page_token=page_token)
            for run in page:
                if run.data.tags.get("sagemaker.job_name") == job_name:
                    return str(run.info.run_id)
            # search_runs returns a PagedList; a falsy token means last page.
            page_token = page.token
            if not page_token:
                return None

    def _log_params(self, params: dict[str, Any]) -> None:
        """Log ``params`` to the active run as strings, skipping ``None`` values."""
        cleaned = {key: str(value) for key, value in params.items() if value is not None}
        if cleaned:
            mlflow.log_params(cleaned)

    def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
        """Log the entries of ``FinalMetricDataList`` as float metrics.

        Entries lacking a ``MetricName`` or a ``Value`` are ignored; nothing
        is logged when no usable entry exists.
        """
        metrics = {}
        for metric in training_job.get("FinalMetricDataList", []):
            name = metric.get("MetricName")
            value = metric.get("Value")
            if name and value is not None:
                metrics[str(name)] = float(value)
        if metrics:
            mlflow.log_metrics(metrics)

    def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
        """Download the model archive and log the metric history found in it.

        The archive is fetched into a temporary directory that is removed on
        exit. Per-step metrics are logged with their step index, summary
        metrics without one, and the raw payload is stored as a run artifact
        named ``METRICS_ARTIFACT_NAME``.

        Returns:
            False when the archive holds no metrics data, True otherwise.
        """
        with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
            archive_path = s3.download_file(
                region,
                profile,
                model_artifacts,
                os.path.join(temp_dir, "model.tar.gz"),
            )
            metrics_data = read_training_metrics_from_tar(archive_path)
            if metrics_data is None:
                return False

            metrics = parse_training_metrics(metrics_data)
            for metric_step in metrics.steps:
                if metric_step.metrics:
                    mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
            if metrics.summary:
                mlflow.log_metrics(metrics.summary)
            mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
            return True

    def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
        """Create the registered model ``name`` if it does not exist yet.

        Only a "resource does not exist" response triggers creation. Any
        other failure (authentication, connectivity, server error) is
        re-raised so the real cause is not masked by a follow-up create call.
        """
        try:
            client.get_registered_model(name)
        except mlflow.exceptions.MlflowException as exc:
            if exc.error_code != "RESOURCE_DOES_NOT_EXIST":
                raise
            client.create_registered_model(name)