@@ -21,6 +21,24 @@ def upload_file(
|
||||
return f"s3://{bucket}/{s3_key}"
|
||||
|
||||
|
||||
def download_file(
|
||||
region: str,
|
||||
profile: str,
|
||||
s3_uri: str,
|
||||
local_path: str,
|
||||
) -> str:
|
||||
if not s3_uri.startswith("s3://"):
|
||||
raise ValueError(f"Expected S3 URI, got: {s3_uri}")
|
||||
bucket_key = s3_uri.removeprefix("s3://")
|
||||
bucket, _, key = bucket_key.partition("/")
|
||||
if not bucket or not key:
|
||||
raise ValueError(f"Expected S3 URI with bucket and key, got: {s3_uri}")
|
||||
dest = Path(local_path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
_client(region, profile).download_file(bucket, key, str(dest))
|
||||
return str(dest)
|
||||
|
||||
|
||||
def upload_dir(
|
||||
region: str,
|
||||
profile: str,
|
||||
|
||||
@@ -121,6 +121,16 @@ def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> Train
|
||||
)
|
||||
|
||||
|
||||
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
|
||||
resp = boto3.Session(profile_name=profile, region_name=region).client("sagemaker").describe_training_job(
|
||||
TrainingJobName=job_name
|
||||
)
|
||||
artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
|
||||
if not artifact:
|
||||
raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
|
||||
return str(artifact)
|
||||
|
||||
|
||||
def list_training_jobs(
|
||||
session: Boto3SessionKwargs,
|
||||
max_results: int = 10,
|
||||
|
||||
Reference in New Issue
Block a user