78 lines
3.0 KiB
Python
78 lines
3.0 KiB
Python
from contextlib import AbstractContextManager
|
|
from dataclasses import dataclass
|
|
from typing import Any, Protocol
|
|
|
|
from src.aws import mlflow as aws_mlflow
|
|
from src.config import Config
|
|
|
|
|
|
class MlflowTrackingBackend(Protocol):
    """Structural interface for a provider that hosts MLflow tracking.

    Any object exposing these attributes and methods satisfies the protocol;
    no inheritance is required. Members are declared with docstring-only
    bodies, so calling them on the protocol class itself returns ``None``.
    """

    @property
    def provider_name(self) -> str:
        """Identifier of the hosting provider (the AWS backend uses ``"aws"``)."""

    @property
    def profile(self) -> str:
        """Profile name the backend uses."""

    @property
    def region(self) -> str:
        """Region the backend targets."""

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the MLflow tracking URI for *tracking_server_name*."""

    def auth_env(self) -> AbstractContextManager[None]:
        """Return a context manager to be entered around tracking-server calls."""

    def training_run_params(
        self,
        training_job: Any,
        *,
        region: str,
        profile: str,
        role_arn: str,
    ) -> dict[str, Any]:
        """Return MLflow run params describing *training_job*.

        ``region``, ``profile`` and ``role_arn`` are keyword-only and come from
        the caller; they are not taken from the backend's own attributes.
        """

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return MLflow run tags identifying *training_job*."""

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Return MLflow params summarising *training_job_status*."""

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return model-version tags derived from *training_job_status*."""
|
|
|
|
|
|
@dataclass(frozen=True)
class AwsMlflowTrackingBackend:
    """AWS implementation of the MLflow tracking-backend interface.

    Tracking-server lookup and authentication are delegated to
    ``src.aws.mlflow``. The remaining methods flatten training-job objects
    into plain dicts keyed ``provider.*`` / ``sagemaker.*``. Instances are
    immutable (``frozen=True``).
    """

    # Profile name forwarded to the ``aws_mlflow`` helpers.
    profile: str
    # Region forwarded to the ``aws_mlflow`` helpers.
    region: str
    # Written into run params under ``provider.name``.
    provider_name: str = "aws"

    def get_tracking_uri(self, tracking_server_name: str) -> str:
        """Return the tracking URI for *tracking_server_name*.

        The value is whatever ``aws_mlflow.get_tracking_server_arn`` returns
        (presumably the server's ARN — confirm in ``src.aws.mlflow``).
        """
        # NOTE(review): this helper is called with (region, profile) while
        # tracking_auth_env in auth_env() is called with (profile, region);
        # verify both signatures in src.aws.mlflow.
        server_uri = aws_mlflow.get_tracking_server_arn(
            self.region, self.profile, tracking_server_name
        )
        return server_uri

    def auth_env(self) -> AbstractContextManager[None]:
        """Return the context manager built by ``aws_mlflow.tracking_auth_env``.

        It is constructed from this backend's profile and region; what it does
        on enter/exit is defined in ``src.aws.mlflow``.
        """
        auth_ctx = aws_mlflow.tracking_auth_env(self.profile, self.region)
        return auth_ctx

    def training_run_params(
        self,
        training_job: Any,
        *,
        region: str,
        profile: str,
        role_arn: str,
    ) -> dict[str, Any]:
        """Flatten *training_job* plus caller-supplied context into run params.

        Args:
            training_job: Object exposing ``job_name``, ``image_uri``,
                ``instance_type``, ``instance_count``, ``s3_train_uri``,
                ``s3_output_path``, ``entry_point`` and ``source_dir``.
            region: Recorded as ``provider.region`` (not ``self.region``).
            profile: Recorded as ``provider.profile`` (not ``self.profile``).
            role_arn: Recorded as ``sagemaker.role_arn``.

        Returns:
            An insertion-ordered dict: the ``provider.*`` entries and the role
            first, followed by the job attributes.

        Raises:
            AttributeError: If *training_job* lacks one of the attributes above.
        """
        run_params: dict[str, Any] = {
            "provider.name": self.provider_name,
            "provider.region": region,
            "provider.profile": profile,
            "sagemaker.role_arn": role_arn,
        }
        # (param key, attribute on training_job) — order defines dict order.
        job_fields = (
            ("sagemaker.job_name", "job_name"),
            ("sagemaker.training_image", "image_uri"),
            ("sagemaker.instance_type", "instance_type"),
            ("sagemaker.instance_count", "instance_count"),
            ("sagemaker.s3_train_uri", "s3_train_uri"),
            ("sagemaker.s3_output_path", "s3_output_path"),
            ("sagemaker.entry_point", "entry_point"),
            ("sagemaker.source_dir", "source_dir"),
        )
        run_params.update((key, getattr(training_job, attr)) for key, attr in job_fields)
        return run_params

    def training_run_tags(self, training_job: Any) -> dict[str, Any]:
        """Return a single ``sagemaker.job_name`` tag from ``training_job.job_name``."""
        job_name = training_job.job_name
        return {"sagemaker.job_name": job_name}

    def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
        """Flatten a training-job status object into run params.

        Reads ``status``, ``created``, ``modified``, ``model_artifacts`` and
        ``failure_reason`` from *training_job_status*, in that order; a missing
        attribute raises ``AttributeError``.
        """
        status_fields = (
            ("sagemaker.training_status", "status"),
            ("sagemaker.created_at", "created"),
            ("sagemaker.modified_at", "modified"),
            ("sagemaker.model_artifacts", "model_artifacts"),
            ("sagemaker.failure_reason", "failure_reason"),
        )
        return {key: getattr(training_job_status, attr) for key, attr in status_fields}

    def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
        """Return a ``sagemaker.job_name`` tag taken from ``training_job_status.name``.

        NOTE(review): this reads ``.name`` whereas training_run_tags reads
        ``.job_name`` from the job object — presumably intentional because the
        status type differs; confirm against the status class.
        """
        status_name = training_job_status.name
        return {"sagemaker.job_name": status_name}
|
|
|
|
|
|
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
    """Build the MLflow tracking backend described by *cfg*.

    Always returns an ``AwsMlflowTrackingBackend`` configured from
    ``cfg.aws.profile`` and ``cfg.aws.region``.
    """
    aws_settings = cfg.aws
    return AwsMlflowTrackingBackend(
        profile=aws_settings.profile,
        region=aws_settings.region,
    )
|