diff --git a/README.md b/README.md index 9c8b076..1c31319 100644 --- a/README.md +++ b/README.md @@ -93,12 +93,6 @@ mlflow: tracking_server_name: your-tracking-server-name ``` -Install the optional MLflow dependencies before enabling MLflow: - -```bash -uv sync --extra mlflow -``` - When MLflow is enabled, `train start` creates an MLflow run for the SageMaker job. `train status` finalizes that run once the job reaches a terminal state and registers completed model artifacts as pre-release model versions using the `prerelease-latest` MLflow alias. To open the managed SageMaker MLflow UI, request a fresh presigned URL: diff --git a/pyproject.toml b/pyproject.toml index de7b1fe..16d8132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,13 +12,9 @@ dependencies = [ "typer==0.25.0", "boto3>=1.34,<1.42", "constructs>=10.0.0", + "mlflow>=3.0", "pydantic>=2.13.3", "pyyaml>=6.0.3", -] - -[project.optional-dependencies] -mlflow = [ - "mlflow>=3.0", "sagemaker-mlflow>=0.4.0", ] diff --git a/src/tracking/mlflow.py b/src/tracking/mlflow.py index 7019b65..2189eba 100644 --- a/src/tracking/mlflow.py +++ b/src/tracking/mlflow.py @@ -4,6 +4,9 @@ import os from dataclasses import dataclass from typing import Any, Protocol +import mlflow +from mlflow.tracking import MlflowClient + from src.aws import mlflow as aws_mlflow from src.config import Config, MlflowMode @@ -30,7 +33,6 @@ class NoopTracker: @dataclass(frozen=True) class MlflowTracker: - mlflow: Any tracking_uri: str experiment_name: str registered_model_name: str @@ -43,14 +45,6 @@ class MlflowTracker: os.environ.setdefault("MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT", "true") - try: - import mlflow - except ImportError as e: - raise RuntimeError( - "MLflow is enabled in config but optional dependencies are not installed. " - "Install with: qc-cli[mlflow]" - ) from e - tracking_server_name = cfg.effective_mlflow_tracking_server_name if not tracking_server_name: raise RuntimeError("MLflow tracking server name could not be resolved.") @@ -64,7 +58,6 @@ class MlflowTracker: mlflow.set_experiment(cfg.mlflow.experiment_name) return cls( - mlflow=mlflow, tracking_uri=tracking_uri, experiment_name=cfg.mlflow.experiment_name, registered_model_name=cfg.mlflow.registered_model_name, @@ -72,7 +65,7 @@ class MlflowTracker: ) def start_training_run(self, training_job: Any, *, region: str, profile: str, role_arn: str) -> str | None: - run = self.mlflow.start_run(run_name=training_job.job_name) + run = mlflow.start_run(run_name=training_job.job_name) run_id = str(run.info.run_id) params = { @@ -90,21 +83,21 @@ class MlflowTracker: } self._log_params(params) self._log_params({f"hyperparameters.{key}": value for key, value in training_job.hyperparameters.items()}) - self.mlflow.set_tags( + mlflow.set_tags( { "qc_cli.stage": "prerelease", "qc_cli.command": "train start", "sagemaker.job_name": training_job.job_name, } ) - self.mlflow.end_run() + mlflow.end_run() return run_id def finalize_training_run(self, *, run_id: str | None, training_job_status: Any) -> str | None: if not run_id: return None - with self.mlflow.start_run(run_id=run_id): + with mlflow.start_run(run_id=run_id): self._log_params( { "sagemaker.training_status": training_job_status.status, @@ -115,16 +108,16 @@ class MlflowTracker: } ) self._log_final_metrics(training_job_status.raw) - self.mlflow.set_tag("qc_cli.command", "train status") + mlflow.set_tag("qc_cli.command", "train status") if training_job_status.status != "Completed" or not training_job_status.model_artifacts: - self.mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) + mlflow.set_tag("qc_cli.training_terminal_status", training_job_status.status) return None if not self.register_trained_models: return None - client = self.mlflow.tracking.MlflowClient() + client = MlflowClient() self._ensure_registered_model(client, self.registered_model_name) version = client.create_model_version( name=self.registered_model_name, @@ -137,14 +130,14 @@ class MlflowTracker: ) version_number = str(version.version) self._set_alias(client, self.registered_model_name, "prerelease-latest", version_number) - self.mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) - self.mlflow.set_tag("qc_cli.registered_model_version", version_number) + mlflow.set_tag("qc_cli.registered_model_name", self.registered_model_name) + mlflow.set_tag("qc_cli.registered_model_version", version_number) return version_number def _log_params(self, params: dict[str, Any]) -> None: cleaned = {key: str(value) for key, value in params.items() if value is not None} if cleaned: - self.mlflow.log_params(cleaned) + mlflow.log_params(cleaned) def _log_final_metrics(self, training_job: dict[str, Any]) -> None: metrics = {} @@ -154,14 +147,13 @@ class MlflowTracker: if name and value is not None: metrics[str(name)] = float(value) if metrics: - self.mlflow.log_metrics(metrics) + mlflow.log_metrics(metrics) - def _ensure_registered_model(self, client: Any, name: str) -> None: + def _ensure_registered_model(self, client: MlflowClient, name: str) -> None: try: client.get_registered_model(name) except Exception: client.create_registered_model(name) - def _set_alias(self, client: Any, name: str, alias: str, version: str) -> None: - if hasattr(client, "set_registered_model_alias"): - client.set_registered_model_alias(name, alias, version) + def _set_alias(self, client: MlflowClient, name: str, alias: str, version: str) -> None: + client.set_registered_model_alias(name, alias, version) diff --git a/uv.lock b/uv.lock index c14c12e..bcbd0d3 100644 --- a/uv.lock +++ b/uv.lock @@ -2011,15 +2011,11 @@ dependencies = [ { name = "aws-cdk-lib" }, { name = "boto3" }, { name = "constructs" }, + { name = "mlflow" }, { name = "pydantic" }, { name = "pyyaml" }, - { name = "typer" }, -] - -[package.optional-dependencies] -mlflow = [ - { name = "mlflow" }, { name = "sagemaker-mlflow" }, + { name = "typer" }, ] [package.dev-dependencies] @@ -2036,13 +2032,12 @@ requires-dist = [ { name = "aws-cdk-lib", specifier = ">=2.180.0" }, { name = "boto3", specifier = ">=1.34,<1.42" }, { name = "constructs", specifier = ">=10.0.0" }, - { name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.0" }, + { name = "mlflow", specifier = ">=3.0" }, { name = "pydantic", specifier = ">=2.13.3" }, { name = "pyyaml", specifier = ">=6.0.3" }, - { name = "sagemaker-mlflow", marker = "extra == 'mlflow'", specifier = ">=0.4.0" }, + { name = "sagemaker-mlflow", specifier = ">=0.4.0" }, { name = "typer", specifier = "==0.25.0" }, ] -provides-extras = ["mlflow"] [package.metadata.requires-dev] dev = [