from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Any, Protocol

from src.aws import mlflow as aws_mlflow
from src.config import Config


class MlflowTrackingBackend(Protocol):
    """Structural interface for a provider backing an MLflow tracking server.

    An implementation does four things:

    * exposes its provider name, profile and region;
    * resolves a tracking URI from a tracking-server name;
    * hands out an auth context manager (``auth_env``);
    * converts provider-specific training-job objects into MLflow
      param/tag dictionaries.
    """

    @property
    def provider_name(self) -> str: ...

    @property
    def profile(self) -> str: ...

    @property
    def region(self) -> str: ...

    def get_tracking_uri(self, tracking_server_name: str) -> str: ...

    def auth_env(self) -> AbstractContextManager[None]: ...

    def training_run_params(
        self, training_job: Any, *, region: str, profile: str, role_arn: str
    ) -> dict[str, Any]: ...

    def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...


# (MLflow param key, attribute name on the training-job object) pairs copied
# verbatim by AwsMlflowTrackingBackend.training_run_params, in output order.
_SAGEMAKER_JOB_PARAM_ATTRS: tuple[tuple[str, str], ...] = (
    ("sagemaker.job_name", "job_name"),
    ("sagemaker.training_image", "image_uri"),
    ("sagemaker.instance_type", "instance_type"),
    ("sagemaker.instance_count", "instance_count"),
    ("sagemaker.s3_train_uri", "s3_train_uri"),
    ("sagemaker.s3_output_path", "s3_output_path"),
    ("sagemaker.entry_point", "entry_point"),
    ("sagemaker.source_dir", "source_dir"),
)

# (MLflow param key, attribute name on the training-job status object) pairs
# copied verbatim by AwsMlflowTrackingBackend.training_status_params.
_SAGEMAKER_STATUS_PARAM_ATTRS: tuple[tuple[str, str], ...] = (
    ("sagemaker.training_status", "status"),
    ("sagemaker.created_at", "created"),
    ("sagemaker.modified_at", "modified"),
    ("sagemaker.model_artifacts", "model_artifacts"),
    ("sagemaker.failure_reason", "failure_reason"),
)


@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
    """AWS implementation of :class:`MlflowTrackingBackend`.

    Tracking-URI lookup and authentication are delegated to
    ``src.aws.mlflow``. SageMaker training-job attributes are recorded
    under ``sagemaker.*`` MLflow keys.
    """

    # AWS profile name, passed through to src.aws.mlflow helpers.
    profile: str
    # AWS region name, passed through to src.aws.mlflow helpers.
    region: str
    # Value logged as the "provider.name" param.
    provider_name: str = "aws"

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the tracking URI for *tracking_server_name*.

        For this backend the URI is whatever
        ``aws_mlflow.get_tracking_server_arn`` returns for this backend's
        region and profile.
        """
        region, profile = self.region, self.profile
        return aws_mlflow.get_tracking_server_arn(region, profile, tracking_server_name)

    def auth_env(self) -> AbstractContextManager[None]:
        """Return the context manager built by ``aws_mlflow.tracking_auth_env``.

        NOTE(review): presumably it sets up AWS credentials for MLflow calls
        made inside the ``with`` block. Confirm in src.aws.mlflow.
        """
        return aws_mlflow.tracking_auth_env(self.profile, self.region)

    def training_run_params(
        self, training_job: Any, *, region: str, profile: str, role_arn: str
    ) -> dict[str, Any]:
        """Build the MLflow params describing a SageMaker training launch.

        The result contains the provider keys followed by one
        ``sagemaker.*`` entry per attribute in ``_SAGEMAKER_JOB_PARAM_ATTRS``.
        A missing attribute raises :class:`AttributeError`.
        """
        # region/profile are logged from the caller's keyword arguments,
        # not from self.region / self.profile.
        params: dict[str, Any] = {
            "provider.name": self.provider_name,
            "provider.region": region,
            "provider.profile": profile,
            "sagemaker.role_arn": role_arn,
        }
        params.update(
            (param_key, getattr(training_job, attr_name))
            for param_key, attr_name in _SAGEMAKER_JOB_PARAM_ATTRS
        )
        return params

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return the MLflow run tags for *training_job* (its job name)."""
        job_name = training_job.job_name
        return {"sagemaker.job_name": job_name}

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Return MLflow params copied from a training-job status object.

        Values are taken as-is from the status object. None is not filtered
        out.
        """
        return {
            param_key: getattr(training_job_status, attr_name)
            for param_key, attr_name in _SAGEMAKER_STATUS_PARAM_ATTRS
        }

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return the model-version tags for *training_job_status*."""
        # NOTE(review): this reads `.name` from the status object, whereas
        # training_run_tags reads `.job_name` from the job object. Confirm the
        # status type exposes the job name as `name`.
        job_name = training_job_status.name
        return {"sagemaker.job_name": job_name}


def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
    """Build the MLflow tracking backend selected by *cfg*.

    This always returns an :class:`AwsMlflowTrackingBackend`, configured
    with the profile and region from ``cfg.aws``.
    """
    return AwsMlflowTrackingBackend(
        profile=cfg.aws.profile,
        region=cfg.aws.region,
    )