include-metrics-from-training (#6)
Reviewed-on: #6
This commit was merged in pull request #6.
This commit is contained in:
77
src/cloud/mlflow.py
Normal file
77
src/cloud/mlflow.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from src.aws import mlflow as aws_mlflow
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class MlflowTrackingBackend(Protocol):
|
||||
@property
|
||||
def provider_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def profile(self) -> str: ...
|
||||
|
||||
@property
|
||||
def region(self) -> str: ...
|
||||
|
||||
def get_tracking_uri(self, tracking_server_name: str) -> str: ...
|
||||
|
||||
def auth_env(self) -> AbstractContextManager[None]: ...
|
||||
|
||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]: ...
|
||||
|
||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]: ...
|
||||
|
||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||
|
||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AwsMlflowTrackingBackend:
|
||||
profile: str
|
||||
region: str
|
||||
provider_name: str = "aws"
|
||||
|
||||
def get_tracking_uri(self, tracking_server_name: str) -> str:
|
||||
return aws_mlflow.get_tracking_server_arn(self.region, self.profile, tracking_server_name)
|
||||
|
||||
def auth_env(self) -> AbstractContextManager[None]:
|
||||
return aws_mlflow.tracking_auth_env(self.profile, self.region)
|
||||
|
||||
def training_run_params(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> dict[str, Any]:
|
||||
return {
|
||||
"provider.name": self.provider_name,
|
||||
"provider.region": region,
|
||||
"provider.profile": profile,
|
||||
"sagemaker.role_arn": role_arn,
|
||||
"sagemaker.job_name": training_job.job_name,
|
||||
"sagemaker.training_image": training_job.image_uri,
|
||||
"sagemaker.instance_type": training_job.instance_type,
|
||||
"sagemaker.instance_count": training_job.instance_count,
|
||||
"sagemaker.s3_train_uri": training_job.s3_train_uri,
|
||||
"sagemaker.s3_output_path": training_job.s3_output_path,
|
||||
"sagemaker.entry_point": training_job.entry_point,
|
||||
"sagemaker.source_dir": training_job.source_dir,
|
||||
}
|
||||
|
||||
def training_run_tags(self, training_job: Any) -> dict[str, Any]:
|
||||
return {"sagemaker.job_name": training_job.job_name}
|
||||
|
||||
def training_status_params(self, training_job_status: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"sagemaker.training_status": training_job_status.status,
|
||||
"sagemaker.created_at": training_job_status.created,
|
||||
"sagemaker.modified_at": training_job_status.modified,
|
||||
"sagemaker.model_artifacts": training_job_status.model_artifacts,
|
||||
"sagemaker.failure_reason": training_job_status.failure_reason,
|
||||
}
|
||||
|
||||
def model_version_tags(self, training_job_status: Any) -> dict[str, Any]:
|
||||
return {"sagemaker.job_name": training_job_status.name}
|
||||
|
||||
|
||||
def mlflow_tracking_backend_from_config(cfg: Config) -> MlflowTrackingBackend:
|
||||
return AwsMlflowTrackingBackend(profile=cfg.aws.profile, region=cfg.aws.region)
|
||||
Reference in New Issue
Block a user