"""Experiment-tracking facade over MLflow for training jobs.

Defines the ``Tracker`` protocol, a ``NoopTracker`` used when MLflow is
disabled, and ``MlflowTracker``, which records training runs, metrics and
(optionally) registered model versions through MLflow.
"""

import os
import tempfile
from dataclasses import dataclass
from typing import Any, Protocol

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient

from src.aws import s3
from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode
from src.tracking.metrics import METRICS_ARTIFACT_NAME, parse_training_metrics, read_training_metrics_from_tar

# Number of runs requested per page while scanning an experiment for a job's run.
_RUN_SEARCH_PAGE_SIZE = 1000

# MLflow error codes treated as "the registered model does not exist yet".
_MODEL_NOT_FOUND_ERROR_CODES = frozenset({"RESOURCE_DOES_NOT_EXIST", "NOT_FOUND"})


@dataclass(frozen=True)
class FinalizeResult:
    """Outcome of finalizing a training run."""

    # Version of the model registered during finalization; None when nothing was registered.
    registered_model_version: str | None = None
    # Non-fatal messages for the caller; no code in this module populates it.
    warnings: tuple[str, ...] = ()


class Tracker(Protocol):
    """Structural interface shared by ``NoopTracker`` and ``MlflowTracker``."""

    def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
        """Record the start of a training job and return its run id, or None if nothing is tracked."""
        ...

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Record the final state of a training job against ``run_id``."""
        ...

    def ensure_training_run(self, job_name: str) -> str:
        """Return the id of the run associated with ``job_name``, creating one if needed."""
        ...

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Upload a job's metrics to ``run_id``; return whether a metrics history was uploaded."""
        ...


@dataclass(frozen=True)
class NoopTracker:
    """Tracker used when MLflow is disabled.

    Starting and finalizing a run silently do nothing; operations that only
    make sense with MLflow enabled raise ``RuntimeError``.
    """

    def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
        """Track nothing and return None."""
        return None

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Track nothing and return an empty ``FinalizeResult``."""
        return FinalizeResult()

    def ensure_training_run(self, job_name: str) -> str:
        """Always raise ``RuntimeError`` because MLflow is disabled."""
        raise RuntimeError("MLflow is disabled.")

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Always raise ``RuntimeError`` because MLflow is disabled."""
        raise RuntimeError("MLflow is disabled.")


@dataclass(frozen=True)
class MlflowTracker:
    """Tracker that records training runs in MLflow.

    Every MLflow call is made inside ``tracking_backend.auth_env()`` so the
    backend can provide whatever environment the tracking server requires.
    """

    tracking_uri: str
    experiment_name: str
    registered_model_name: str
    register_trained_models: bool
    tracking_backend: MlflowTrackingBackend

    @classmethod
    def from_config(cls, cfg: Config) -> Tracker:
        """Build the tracker selected by ``cfg``.

        Returns a ``NoopTracker`` when ``cfg.mlflow.mode`` is disabled.
        Otherwise resolves the tracking URI through the configured backend,
        points the global MLflow client at it, selects the experiment and
        returns an ``MlflowTracker``.

        Raises:
            RuntimeError: if no tracking server name can be resolved from ``cfg``.
        """
        if cfg.mlflow.mode is MlflowMode.disabled:
            return NoopTracker()

        # Only applied when the caller has not already set the variable.
        os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true")

        tracking_server_name = cfg.effective_mlflow_tracking_server_name
        if not tracking_server_name:
            raise RuntimeError("MLflow tracking server name could not be resolved.")

        tracking_backend = mlflow_tracking_backend_from_config(cfg)
        tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
        with tracking_backend.auth_env():
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.set_experiment(cfg.mlflow.experiment_name)

        return cls(
            tracking_uri=tracking_uri,
            experiment_name=cfg.mlflow.experiment_name,
            registered_model_name=cfg.mlflow.registered_model_name,
            register_trained_models=cfg.mlflow.register_trained_models,
            tracking_backend=tracking_backend,
        )

    def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
        """Create an MLflow run named after the job and log its parameters and tags.

        Logs the backend-provided run parameters, each hyperparameter under a
        ``hyperparameters.`` prefix, and the ``qc_cli.*`` tags merged with the
        backend's own run tags.

        Returns:
            The id of the newly created run.
        """
        with self.tracking_backend.auth_env():
            with mlflow.start_run(run_name=training_job.job_name) as run:
                run_id = str(run.info.run_id)
                self._log_params(
                    self.tracking_backend.training_run_params(
                        training_job,
                        region=region,
                        profile=profile,
                        role_arn=role_arn,
                    )
                )
                self._log_params(
                    {f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}
                )
                mlflow.set_tags(
                    {
                        **self._base_tags(),
                        "qc_cli.command": "train start",
                        # Backend tags come last, so they win on key collisions.
                        **self.tracking_backend.training_run_tags(training_job),
                    }
                )
        return run_id

    def finalize_training_run(
        self,
        *,
        run_id: str | None,
        training_job_status: Any,
        region: str,
        profile: str,
        command: str,
    ) -> FinalizeResult:
        """Log the job's final status to ``run_id`` and optionally register the model.

        Does nothing when ``run_id`` is falsy. A model version is registered
        only when the job status is ``"Completed"``, model artifacts are
        present and ``register_trained_models`` is set; the new version is
        also given the ``experiment-latest`` alias.

        Returns:
            A ``FinalizeResult`` carrying the registered version number, or an
            empty result when no model was registered.
        """
        if not run_id:
            return FinalizeResult()

        with self.tracking_backend.auth_env():
            with mlflow.start_run(run_id=run_id):
                self._log_params(self.tracking_backend.training_status_params(training_job_status))
                self._log_final_metrics(training_job_status.raw)
                mlflow.set_tag("qc_cli.command", command)

                if training_job_status.status != "Completed" or not training_job_status.model_artifacts:
                    mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status)
                    return FinalizeResult()

                if not self.register_trained_models:
                    return FinalizeResult()

                client = MlflowClient()
                self._ensure_registered_model(client, self.registered_model_name)
                version = client.create_model_version(
                    name=self.registered_model_name,
                    source=training_job_status.model_artifacts,
                    run_id=run_id,
                    tags={
                        **self._base_tags(),
                        **self.tracking_backend.model_version_tags(training_job_status),
                    },
                )
                version_number = str(version.version)
                client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number)
                mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name)
                mlflow.set_tag("qc_cli.registered_model_version", version_number)
                return FinalizeResult(registered_model_version=version_number)

    def ensure_training_run(self, job_name: str) -> str:
        """Return the id of the run tagged with ``job_name``, creating it if absent.

        The experiment is created when it does not exist yet. Existing runs
        are matched on their ``sagemaker.job_name`` tag; every result page is
        scanned, so a run is found even when the experiment holds more runs
        than fit in a single page.
        """
        with self.tracking_backend.auth_env():
            client = MlflowClient()
            experiment = client.get_experiment_by_name(self.experiment_name)
            if experiment is None:
                experiment_id = mlflow.create_experiment(self.experiment_name)
            else:
                experiment_id = experiment.experiment_id

            existing_run_id = self._find_run_id_for_job(client, experiment_id, job_name)
            if existing_run_id is not None:
                return existing_run_id

            run = client.create_run(
                experiment_id,
                run_name=job_name,
                tags={
                    **self._base_tags(),
                    "qc_cli.command": "mlflow upload-metrics",
                    "sagemaker.job_name": job_name,
                },
            )
            return str(run.info.run_id)

    def upload_training_metrics(
        self,
        *,
        run_id: str,
        training_job_status: Any,
        region: str,
        profile: str,
    ) -> bool:
        """Log status parameters, final metrics and the metrics history to ``run_id``.

        Returns:
            True when a metrics history was found in the model archive and
            logged, False otherwise. The same value is stored, lower-cased, in
            the ``qc_cli.metrics_history_uploaded`` tag.

        Raises:
            ValueError: if the training job has no model artifacts.
        """
        if not training_job_status.model_artifacts:
            raise ValueError(f"Training job '{training_job_status.name}' has no model artifacts.")

        with self.tracking_backend.auth_env():
            with mlflow.start_run(run_id=run_id):
                self._log_params(self.tracking_backend.training_status_params(training_job_status))
                self._log_final_metrics(training_job_status.raw)
                history_uploaded = self._log_training_metrics(
                    training_job_status.model_artifacts,
                    region=region,
                    profile=profile,
                )
                mlflow.set_tag("qc_cli.command", "mlflow upload-metrics")
                mlflow.set_tag("qc_cli.metrics_history_uploaded", str(history_uploaded).lower())
                return history_uploaded

    def _base_tags(self) -> dict[str, str]:
        """Return the ``qc_cli.*`` tags shared by runs and model versions."""
        return {
            "qc_cli.stage": "experiment",
            "qc_cli.artifact_kind": "trained_source",
            "qc_cli.source": self.tracking_backend.provider_name,
        }

    def _find_run_id_for_job(self, client: MlflowClient, experiment_id: str, job_name: str) -> str | None:
        """Return the id of the first run tagged with ``job_name``, or None if there is none."""
        # The first request is identical to a plain, non-paginated search.
        page = client.search_runs([experiment_id], max_results=_RUN_SEARCH_PAGE_SIZE)
        while True:
            for run in page:
                if run.data.tags.get("sagemaker.job_name") == job_name:
                    return str(run.info.run_id)
            # A missing or empty token means there are no further pages.
            next_page_token = getattr(page, "token", None)
            if not next_page_token:
                return None
            page = client.search_runs(
                [experiment_id],
                max_results=_RUN_SEARCH_PAGE_SIZE,
                page_token=next_page_token,
            )

    def _log_params(self, params: dict[str, Any]) -> None:
        """Log ``params`` to the active run as strings, skipping None values."""
        cleaned = {key: str(value) for key, value in params.items() if value is not None}
        if cleaned:
            mlflow.log_params(cleaned)

    def _log_final_metrics(self, training_job: dict[str, Any]) -> None:
        """Log the entries of ``FinalMetricDataList`` to the active run.

        Entries without a ``MetricName`` or without a ``Value`` are skipped.
        """
        metrics = {}
        # `or []` also covers a key that is present but explicitly None.
        for metric in training_job.get("FinalMetricDataList") or []:
            name = metric.get("MetricName")
            value = metric.get("Value")
            if name and value is not None:
                metrics[str(name)] = float(value)
        if metrics:
            mlflow.log_metrics(metrics)

    def _log_training_metrics(self, model_artifacts: str, *, region: str, profile: str) -> bool:
        """Download the model archive and log the metrics it contains.

        Per-step metrics are logged with their step number, summary metrics
        without one, and the raw metrics are stored as an artifact named
        ``METRICS_ARTIFACT_NAME``.

        Returns:
            False when the archive yields no metrics data, True otherwise.
        """
        # Everything stays inside the context so the downloaded archive is still
        # on disk for as long as the metrics are being read and logged.
        with tempfile.TemporaryDirectory(prefix="qc-cli-metrics-") as temp_dir:
            archive_path = s3.download_file(
                region,
                profile,
                model_artifacts,
                os.path.join(temp_dir, "model.tar.gz"),
            )
            metrics_data = read_training_metrics_from_tar(archive_path)
            if metrics_data is None:
                return False

            metrics = parse_training_metrics(metrics_data)
            for metric_step in metrics.steps:
                if metric_step.metrics:
                    mlflow.log_metrics(metric_step.metrics, step=metric_step.step)
            if metrics.summary:
                mlflow.log_metrics(metrics.summary)
            mlflow.log_dict(metrics.raw, METRICS_ARTIFACT_NAME)
            return True

    def _ensure_registered_model(self, client: MlflowClient, name: str) -> None:
        """Create the registered model ``name`` unless it already exists.

        Only a "not found" answer from MLflow triggers creation. Any other
        failure (authentication, connectivity, ...) is re-raised unchanged, so
        its real cause is not hidden behind a follow-up creation error.
        """
        try:
            client.get_registered_model(name)
        except MlflowException as exc:
            if exc.error_code not in _MODEL_NOT_FOUND_ERROR_CODES:
                raise
            client.create_registered_model(name)