Mlflow implementation (#2)
Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
@@ -17,3 +17,20 @@ def describe_tracking_server(region: str, profile: str, name: str) -> dict[str,
|
||||
):
|
||||
return None
|
||||
raise
|
||||
|
||||
|
||||
def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
|
||||
server = describe_tracking_server(region, profile, name)
|
||||
if not server:
|
||||
raise ValueError(f"MLflow tracking server not found: {name}")
|
||||
|
||||
arn = server.get("TrackingServerArn")
|
||||
if not arn:
|
||||
raise ValueError(f"MLflow tracking server has no ARN: {name}")
|
||||
return str(arn)
|
||||
|
||||
|
||||
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
|
||||
client = boto3.Session(profile_name=profile, region_name=region).client("sagemaker")
|
||||
response = client.create_presigned_mlflow_tracking_server_url(TrackingServerName=name)
|
||||
return str(response["AuthorizedUrl"])
|
||||
|
||||
@@ -36,6 +36,7 @@ class TrainingJobStatus:
|
||||
modified: datetime | None
|
||||
model_artifacts: str | None
|
||||
failure_reason: str | None
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
|
||||
@@ -116,6 +117,7 @@ def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> Train
|
||||
modified=resp.get("LastModifiedTime"),
|
||||
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
|
||||
failure_reason=resp.get("FailureReason"),
|
||||
raw=dict(resp),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user