From 3846c5d88d78759ba97b7257f80e4a39666a6ea6 Mon Sep 17 00:00:00 2001 From: slalom Date: Fri, 5 Jun 2026 15:52:55 -0400 Subject: [PATCH] add aws context for MLFlow --- src/aws/mlflow.py | 38 ++++++++++++ src/cloud/__init__.py | 1 + src/cloud/mlflow.py | 77 ++++++++++++++++++++++++ src/tracking/mlflow.py | 131 +++++++++++++++++++---------------------- 4 files changed, 176 insertions(+), 71 deletions(-) create mode 100644 src/cloud/__init__.py create mode 100644 src/cloud/mlflow.py diff --git a/src/aws/mlflow.py b/src/aws/mlflow.py index 35433bb..344e70f 100644 --- a/src/aws/mlflow.py +++ b/src/aws/mlflow.py @@ -1,3 +1,6 @@ +import os +from collections.abc import Generator +from contextlib import contextmanager from typing import Any, cast import boto3 @@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) - client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) return str(response["AuthorizedUrl"]) + + +@contextmanager +def tracking_auth_env(profile: str, region: str) -> Generator[None]: + credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials() + if credentials is None: + raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.") + + frozen_credentials = credentials.get_frozen_credentials() + if not frozen_credentials.access_key or not frozen_credentials.secret_key: + raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.") + + env_updates = { + "AWS_PROFILE": profile, + "AWS_DEFAULT_REGION": region, + "AWS_REGION": region, + "AWS_ACCESS_KEY_ID": frozen_credentials.access_key, + "AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key, + } + if frozen_credentials.token: + env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token + + restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"} + previous_env = {key: os.environ.get(key) for key in restore_keys} + try: + os.environ.update(env_updates) + if not frozen_credentials.token: + os.environ.pop("AWS_SESSION_TOKEN", None) + yield + finally: + for key, value in previous_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/src/cloud/__init__.py b/src/cloud/__init__.py new file mode 100644 index 0000000..75bf900 --- /dev/null +++ b/src/cloud/__init__.py @@ -0,0 +1 @@ +"""Cloud provider adapters.""" diff --git a/src/cloud/mlflow.py b/src/cloud/mlflow.py new file mode 100644 index 0000000..08ae203 --- /dev/null +++ b/src/cloud/mlflow.py @@ -0,0 +1,77 @@ +from contextlib import AbstractContextManager +from dataclasses import dataclass +from typing import Any, Protocol + +from src.aws import mlflow as aws_mlflow +from src.config import Config + + +class MlflowTrackingBackend(Protocol): + @property + def provider_name(self) -> str: ... + + @property + def profile(self) -> str: ... + + @property + def region(self) -> str: ... + + def get_tracking_uri(self, tracking_server_name: str) -> str: ... + + def auth_env(self) -> AbstractContextManager[None]: ... + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ... + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: ... + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ... + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ... + + +@dataclass(frozen=True) +class AwsMlflowTrackingBackend: + profile: str + region: str + provider_name: str = "aws" + + def get_tracking_uri(self, tracking_server_name: str) -> str: + return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name) + + def auth_env(self) -> AbstractContextManager[None]: + return aws_mlflow.tracking_auth_env(self.profile, self.region) + + def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: + return { + "provider.name": self.provider_name, + "provider.region": region, + "provider.profile": profile, + "sagemaker.role_arn": role_arn, + "sagemaker.job_name": training_job.job_name, + "sagemaker.training_image": training_job.image_uri, + "sagemaker.instance_type": training_job.instance_type, + "sagemaker.instance_count": training_job.instance_count, + "sagemaker.s3_train_uri": training_job.s3_train_uri, + "sagemaker.s3_output_path": training_job.s3_output_path, + "sagemaker.entry_point": training_job.entry_point, + "sagemaker.source_dir": training_job.source_dir, + } + + def training_run_tags(self, training_job: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job.job_name} + + def training_status_params(self, training_job_status: Any) -> dict[str, Any]: + return { + "sagemaker.training_status": training_job_status.status, + "sagemaker.created_at": training_job_status.created, + "sagemaker.modified_at": training_job_status.modified, + "sagemaker.model_artifacts": training_job_status.model_artifacts, + "sagemaker.failure_reason": training_job_status.failure_reason, + } + + def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: + return {"sagemaker.job_name": training_job_status.name} + + +def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend: + return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region) diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 0e8f5d0..2483870 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -5,7 +5,7 @@ from typing import Any, Protocol import mlflow from mlflow.tracking import MlflowClient -from src.aws import mlflow as aws_mlflow +from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config from src.config import Config, MlflowMode @@ -30,6 +30,7 @@ class MlflowTracker: experiment_name: str registered_model_name: str register_trained_models: bool + tracking_backend: MlflowTrackingBackend @classmethod def from_config(cls, cfg: Config) -> Tracker: @@ -42,94 +43,82 @@ class MlflowTracker: if not tracking_server_name: raise RuntimeError("MLflow tracking server name could not be resolved.") - tracking_uri = aws_mlflow.get_tracking_server_arn( - cfg.aws.region, - cfg.aws.profile, - tracking_server_name, - ) - mlflow.set_tracking_uri(tracking_uri) - mlflow.set_experiment(cfg.mlflow.experiment_name) + tracking_backend = mlflow_tracking_backend_from_config(cfg) + + tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name) + with tracking_backend.auth_env(): + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(cfg.mlflow.experiment_name) return cls( tracking_uri=tracking_uri, experiment_name=cfg.mlflow.experiment_name, registered_model_name=cfg.mlflow.registered_model_name, register_trained_models=cfg.mlflow.register_trained_models, + tracking_backend=tracking_backend, ) def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: - run = mlflow.start_run(run_name=training_job.job_name) - run_id = str(run.info.run_id) + with self.tracking_backend.auth_env(): + run = mlflow.start_run(run_name=training_job.job_name) + run_id = str(run.info.run_id) - params = { - "aws.region": region, - "aws.profile": profile, - "sagemaker.role_arn": role_arn, - "sagemaker.job_name": training_job.job_name, - "sagemaker.training_image": training_job.image_uri, - "sagemaker.instance_type": training_job.instance_type, - "sagemaker.instance_count": training_job.instance_count, - "sagemaker.s3_train_uri": training_job.s3_train_uri, - "sagemaker.s3_output_path": training_job.s3_output_path, - "sagemaker.entry_point": training_job.entry_point, - "sagemaker.source_dir": training_job.source_dir, - } - self._log_params(params) - self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) - mlflow.set_tags( - { - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "qc_cli.command": "train start", - "sagemaker.job_name": training_job.job_name, - } - ) - mlflow.end_run() - return run_id + self._log_params( + self.tracking_backend.training_run_params( + training_job, + region=region, + profile=profile, + role_arn=role_arn, + ) + ) + self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) + mlflow.set_tags( + { + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + "qc_cli.command": "train start", + **self.tracking_backend.training_run_tags(training_job), + } + ) + mlflow.end_run() + return run_id def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: if not run_id: return None - with mlflow.start_run(run_id=run_id): - self._log_params( - { - "sagemaker.training_status": training_job_status.status, - "sagemaker.created_at": training_job_status.created, - "sagemaker.modified_at": training_job_status.modified, - "sagemaker.model_artifacts": training_job_status.model_artifacts, - "sagemaker.failure_reason": training_job_status.failure_reason, - } - ) - self._log_final_metrics(training_job_status.raw) - mlflow.set_tag("qc_cli.command", "train status") + with self.tracking_backend.auth_env(): + with mlflow.start_run(run_id=run_id): + self._log_params(self.tracking_backend.training_status_params(training_job_status)) + self._log_final_metrics(training_job_status.raw) + mlflow.set_tag("qc_cli.command", "train status") - if training_job_status.status != "Completed" or not training_job_status.model_artifacts: - mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) - return None + if training_job_status.status != "Completed" or not training_job_status.model_artifacts: + mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) + return None - if not self.register_trained_models: - return None + if not self.register_trained_models: + return None - client = MlflowClient() - self._ensure_registered_model(client, self.registered_model_name) - version = client.create_model_version( - name=self.registered_model_name, - source=training_job_status.model_artifacts, - run_id=run_id, - tags={ - "qc_cli.stage": "experiment", - "qc_cli.artifact_kind": "trained_source", - "qc_cli.source": "sagemaker", - "sagemaker.job_name": training_job_status.name, - }, - ) - version_number = str(version.version) - client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) - mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) - mlflow.set_tag("qc_cli.registered_model_version", version_number) - return version_number + client = MlflowClient() + self._ensure_registered_model(client, self.registered_model_name) + version = client.create_model_version( + name=self.registered_model_name, + source=training_job_status.model_artifacts, + run_id=run_id, + tags={ + "qc_cli.stage": "experiment", + "qc_cli.artifact_kind": "trained_source", + "qc_cli.source": self.tracking_backend.provider_name, + **self.tracking_backend.model_version_tags(training_job_status), + }, + ) + version_number = str(version.version) + client.set_registered_model_alias(self.registered_model_name, "experiment-latest", version_number) + mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) + mlflow.set_tag("qc_cli.registered_model_version", version_number) + return version_number def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None}