1 Commits

Author SHA1 Message Date
3846c5d88d add aws context for MLFlow 2026-06-05 15:52:55 -04:00
4 changed files with 176 additions and 71 deletions

View File

@@ -1,3 +1,6 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast from typing import Any, cast
import boto3 import boto3
@@ -34,3 +37,38 @@ def create_presigned_tracking_server_url(region: str, profile: str, name: str) -
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker") client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name) response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
return str(response["AuthorizedUrl"]) return str(response["AuthorizedUrl"])
@contextmanager
def tracking_auth_env(profile: str, region: str) -> Generator[None]:
credentials = boto3.Session(profile_name=profile, region_name=region).get_credentials()
if credentials is None:
raise RuntimeError(f"AWS credentials could not be resolved for profile '{profile}'.")
frozen_credentials = credentials.get_frozen_credentials()
if not frozen_credentials.access_key or not frozen_credentials.secret_key:
raise RuntimeError(f"AWS credentials are incomplete for profile '{profile}'.")
env_updates = {
"AWS_PROFILE": profile,
"AWS_DEFAULT_REGION": region,
"AWS_REGION": region,
"AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
"AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
}
if frozen_credentials.token:
env_updates["AWS_SESSION_TOKEN"] = frozen_credentials.token
restore_keys = set(env_updates) | {"AWS_SESSION_TOKEN"}
previous_env = {key: os.environ.get(key) for key in restore_keys}
try:
os.environ.update(env_updates)
if not frozen_credentials.token:
os.environ.pop("AWS_SESSION_TOKEN", None)
yield
finally:
for key, value in previous_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

1
src/cloud/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Cloud provider adapters."""

77
src/cloud/mlflow.py Normal file
View File

@@ -0,0 +1,77 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol
from src.aws import mlflow as aws_mlflow
from src.config import Config
class MlflowTrackingBackend(Protocol):
@property
def provider_name(self) -> str: ...
@property
def profile(self) -> str: ...
@property
def region(self) -> str: ...
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
def auth_env(self) -> AbstractContextManager[None]: ...
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
profile: str
region: str
provider_name: str = "aws"
def get_tracking_uri(self, tracking_server_name: str) -> str:
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
def auth_env(self) -> AbstractContextManager[None]:
return aws_mlflow.tracking_auth_env(self.profile, self.region)
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
return {
"provider.name": self.provider_name,
"provider.region": region,
"provider.profile": profile,
"sagemaker.role_arn": role_arn,
"sagemaker.job_name": training_job.job_name,
"sagemaker.training_image": training_job.image_uri,
"sagemaker.instance_type": training_job.instance_type,
"sagemaker.instance_count": training_job.instance_count,
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job.job_name}
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
return {
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
return {"sagemaker.job_name": training_job_status.name}
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)

View File

@@ -5,7 +5,7 @@ from typing import Any, Protocol
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from src.aws import mlflow as aws_mlflow from src.cloud.mlflow import MlflowTrackingBackend, mlflow_tracking_backend_from_config
from src.config import Config, MlflowMode from src.config import Config, MlflowMode
@@ -30,6 +30,7 @@ class MlflowTracker:
experiment_name: str experiment_name: str
registered_model_name: str registered_model_name: str
register_trained_models: bool register_trained_models: bool
tracking_backend: MlflowTrackingBackend
@classmethod @classmethod
def from_config(cls, cfg: Config) -> Tracker: def from_config(cls, cfg: Config) -> Tracker:
@@ -42,11 +43,10 @@ class MlflowTracker:
if not tracking_server_name: if not tracking_server_name:
raise RuntimeError("MLflow tracking server name could not be resolved.") raise RuntimeError("MLflow tracking server name could not be resolved.")
tracking_uri = aws_mlflow.get_tracking_server_arn( tracking_backend = mlflow_tracking_backend_from_config(cfg)
cfg.aws.region,
cfg.aws.profile, tracking_uri = tracking_backend.get_tracking_uri(tracking_server_name)
tracking_server_name, with tracking_backend.auth_env():
)
mlflow.set_tracking_uri(tracking_uri) mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(cfg.mlflow.experiment_name) mlflow.set_experiment(cfg.mlflow.experiment_name)
@@ -55,34 +55,30 @@ class MlflowTracker:
experiment_name=cfg.mlflow.experiment_name, experiment_name=cfg.mlflow.experiment_name,
registered_model_name=cfg.mlflow.registered_model_name, registered_model_name=cfg.mlflow.registered_model_name,
register_trained_models=cfg.mlflow.register_trained_models, register_trained_models=cfg.mlflow.register_trained_models,
tracking_backend=tracking_backend,
) )
def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None:
with self.tracking_backend.auth_env():
run = mlflow.start_run(run_name=training_job.job_name) run = mlflow.start_run(run_name=training_job.job_name)
run_id = str(run.info.run_id) run_id = str(run.info.run_id)
params = { self._log_params(
"aws.region": region, self.tracking_backend.training_run_params(
"aws.profile": profile, training_job,
"sagemaker.role_arn": role_arn, region=region,
"sagemaker.job_name": training_job.job_name, profile=profile,
"sagemaker.training_image": training_job.image_uri, role_arn=role_arn,
"sagemaker.instance_type": training_job.instance_type, )
"sagemaker.instance_count": training_job.instance_count, )
"sagemaker.s3_train_uri": training_job.s3_train_uri,
"sagemaker.s3_output_path": training_job.s3_output_path,
"sagemaker.entry_point": training_job.entry_point,
"sagemaker.source_dir": training_job.source_dir,
}
self._log_params(params)
self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()})
mlflow.set_tags( mlflow.set_tags(
{ {
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker", "qc_cli.source": self.tracking_backend.provider_name,
"qc_cli.command": "train start", "qc_cli.command": "train start",
"sagemaker.job_name": training_job.job_name, **self.tracking_backend.training_run_tags(training_job),
} }
) )
mlflow.end_run() mlflow.end_run()
@@ -92,16 +88,9 @@ class MlflowTracker:
if not run_id: if not run_id:
return None return None
with self.tracking_backend.auth_env():
with mlflow.start_run(run_id=run_id): with mlflow.start_run(run_id=run_id):
self._log_params( self._log_params(self.tracking_backend.training_status_params(training_job_status))
{
"sagemaker.training_status": training_job_status.status,
"sagemaker.created_at": training_job_status.created,
"sagemaker.modified_at": training_job_status.modified,
"sagemaker.model_artifacts": training_job_status.model_artifacts,
"sagemaker.failure_reason": training_job_status.failure_reason,
}
)
self._log_final_metrics(training_job_status.raw) self._log_final_metrics(training_job_status.raw)
mlflow.set_tag("qc_cli.command", "train status") mlflow.set_tag("qc_cli.command", "train status")
@@ -121,8 +110,8 @@ class MlflowTracker:
tags={ tags={
"qc_cli.stage": "experiment", "qc_cli.stage": "experiment",
"qc_cli.artifact_kind": "trained_source", "qc_cli.artifact_kind": "trained_source",
"qc_cli.source": "sagemaker", "qc_cli.source": self.tracking_backend.provider_name,
"sagemaker.job_name": training_job_status.name, **self.tracking_backend.model_version_tags(training_job_status),
}, },
) )
version_number = str(version.version) version_number = str(version.version)