Reviewed-on: #3
This commit was merged in pull request #3.
This commit is contained in:
2026-06-03 21:06:06 +00:00
parent e9ada2612f
commit a3f3060e13
16 changed files with 1161 additions and 56 deletions

View File

@@ -21,6 +21,24 @@ def upload_file(
return f"s3://{bucket}/{s3_key}"
def download_file(
region: str,
profile: str,
s3_uri: str,
local_path: str,
) -> str:
if not s3_uri.startswith("s3://"):
raise ValueError(f"Expected S3 URI, got: {s3_uri}")
bucket_key = s3_uri.removeprefix("s3://")
bucket, _, key = bucket_key.partition("/")
if not bucket or not key:
raise ValueError(f"Expected S3 URI with bucket and key, got: {s3_uri}")
dest = Path(local_path)
dest.parent.mkdir(parents=True, exist_ok=True)
_client(region, profile).download_file(bucket, key, str(dest))
return str(dest)
def upload_dir(
region: str,
profile: str,

View File

@@ -121,6 +121,16 @@ def get_training_job_status(session: Boto3SessionKwargs, job_name: str) -> Train
)
def get_model_artifacts(region: str, profile: str, job_name: str) -> str:
resp = boto3.Session(profile_name=profile, region_name=region).client("sagemaker").describe_training_job(
TrainingJobName=job_name
)
artifact = resp.get("ModelArtifacts", {}).get("S3ModelArtifacts")
if not artifact:
raise RuntimeError(f"Training job '{job_name}' does not have model artifacts yet.")
return str(artifact)
def list_training_jobs(
session: Boto3SessionKwargs,
max_results: int = 10,