Mlflow implementation (#2)

Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
2026-06-02 19:04:23 +00:00
parent 6ac9702dc5
commit e9ada2612f
13 changed files with 2287 additions and 38 deletions

View File

@@ -17,3 +17,20 @@ def describe_tracking_server(region: str, profile: str, name: str) -> dict[str,
):
return None
raise
def get_tracking_server_arn(region: str, profile: str, name: str) -> str:
    """Return the ARN of the named MLflow tracking server.

    The server description is obtained from ``describe_tracking_server``
    and its ``"TrackingServerArn"`` entry is returned as a string.

    Args:
        region: Passed through to ``describe_tracking_server``.
        profile: Passed through to ``describe_tracking_server``.
        name: Tracking server name to look up; also used in error messages.

    Returns:
        The value stored under ``"TrackingServerArn"``, converted to ``str``.

    Raises:
        ValueError: If the lookup yields a falsy result (no server), or the
            description has a missing or empty ``"TrackingServerArn"``.
    """
    description = describe_tracking_server(region, profile, name)
    if description:
        tracking_arn = description.get("TrackingServerArn")
        if tracking_arn:
            return str(tracking_arn)
        # Description exists but carries no usable ARN value.
        raise ValueError(f"MLflow tracking server has no ARN: {name}")
    raise ValueError(f"MLflow tracking server not found: {name}")
def create_presigned_tracking_server_url(region: str, profile: str, name: str) -> str:
    """Request a presigned URL for the named MLflow tracking server.

    Builds a boto3 session from the given profile and region, creates a
    ``"sagemaker"`` client, and calls
    ``create_presigned_mlflow_tracking_server_url`` for ``name``.

    Args:
        region: Region name used for the boto3 session.
        profile: Profile name used for the boto3 session.
        name: Sent as ``TrackingServerName`` in the request.

    Returns:
        The ``"AuthorizedUrl"`` field of the response, converted to ``str``.

    Raises:
        KeyError: If the response has no ``"AuthorizedUrl"`` key.
    """
    session = boto3.Session(profile_name=profile, region_name=region)
    sagemaker = session.client("sagemaker")
    result = sagemaker.create_presigned_mlflow_tracking_server_url(
        TrackingServerName=name,
    )
    return str(result["AuthorizedUrl"])

View File

@@ -36,6 +36,7 @@ class TrainingJobStatus:
modified: datetime | None
model_artifacts: str | None
failure_reason: str | None
raw: dict[str, Any] = field(default_factory=dict)
def _sm(session: Boto3SessionKwargs) -> SageMakerClient:
@@ -116,6 +117,7 @@ def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> Train
modified=resp.get("LastModifiedTime"),
model_artifacts=resp.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
failure_reason=resp.get("FailureReason"),
raw=dict(resp),
)