Skip to content

Commit

Permalink
Merge pull request #9 from nederlandsespoorwegen/update-doc-strings
Browse files Browse the repository at this point in the history
Update doc strings
  • Loading branch information
luuk-at-NS authored Feb 18, 2023
2 parents 2edde92 + 926f7ac commit 8d21b77
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 16 deletions.
117 changes: 102 additions & 15 deletions environment_mlflow_client/env_mlflow_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""mlflow client that is aware of the application environment, main target is databricks mlflow"""
"""mlflow client that is aware of the application environment. The main target is Databricks mlflow within one workspace."""
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
Expand All @@ -15,6 +15,8 @@ class EnvMlflowClient(mlflow.tracking.MlflowClient):
* load a (latest) model version
* log and register a model version and set the stage property
The environment is set by the environment variable MLFLOW_ENV or by passing the env_name argument in the init.
"""

ENVIRONMENT_KEY = "MLFLOW_ENV"
Expand All @@ -26,11 +28,17 @@ def __init__(
registry_uri: Optional[str] = None,
):
"""
For databricks we do not provide these arguments but rely on environment variables.
Create an EnvMlflowClient instance scoped for one logical environment.
The environment is set by the environment variable MLFLOW_ENV or by passing the env_name argument.
note: On Databricks we do not provide the arguments and rely on environment variables.
Args:
env_name: environment name
env_name: environment name overrides the environment variable MLFLOW_ENV
tracking_uri: Address of local or remote tracking server.
registry_uri: Address of local or remote model registry server.
Returns:
mlflow Client
"""
super().__init__(tracking_uri, registry_uri)
if env_name is None and not self.ENVIRONMENT_KEY in os.environ:
Expand All @@ -48,7 +56,13 @@ def get_env_model_name(self, name: str) -> str:
def get_latest_versions(self, name: str) -> List[ModelVersion]:
"""
Get latest model version.
The stage parameter is not supported as we set its value.
note: The stage parameter is not supported as we set its value.
Args:
name: Name of the model
Returns:
List of ModelVersion objects
"""
name = self.get_env_model_name(name)
return super().get_latest_versions(
Expand All @@ -63,6 +77,7 @@ def get_latest_model_version(self, name: str) -> ModelVersion:
name: Name of the model
Returns:
Latest ModelVersion
"""
return self.get_latest_versions(name)[0]

Expand All @@ -71,13 +86,28 @@ def set_model_version_tag(
) -> None:
"""
Set a tag on a model version
The stage parameter is not supported as we set its value.
note: The stage parameter is not supported.
Args:
name: Name of the model
version: Version of the model
key: Tag key
value: Tag value
"""
name = self.get_env_model_name(name)
super().set_model_version_tag(name, version, key, value)

def set_registered_model_tag(self, name: str, key: str, value: Any) -> None:
"""Set a tag on a registered model"""
"""
Set a tag on a registered model.
Args:
name: Name of the model
key: Tag key
value: Tag value
"""
name = self.get_env_model_name(name)
super().set_registered_model_tag(name, key, value)

Expand All @@ -91,7 +121,20 @@ def create_model_version(
description: Optional[str] = None,
await_creation_for: int = ...,
) -> ModelVersion:
"""
Create a new model version
Args:
name: Name of the model
source: Path to the model
run_id: ID of the run that created the model
tags: Tags to set on the model
run_link: Link to the run that created the model
description: Description of the model
await_creation_for: Number of seconds to wait for the model version to finish being created.
Returns:
mlflow ModelVersion
"""
name = self.get_env_model_name(name)
return super().create_model_version(
name, source, run_id, tags, run_link, description, await_creation_for
Expand All @@ -103,10 +146,22 @@ def create_registered_model(
tags: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
) -> RegisteredModel:
"""
Create a new registered model
Args:
name: Name of the model
tags: Tags to set on the model
description: Description of the model
Returns:
mlflow RegisteredModel
"""
name = self.get_env_model_name(name)
return super().create_registered_model(name, tags, description)

def get_model_version_download_uri(self, name: str, version: str) -> str:
"""Get the download URI of a model version"""
name = self.get_env_model_name(name)
return super().get_model_version_download_uri(name, version)

Expand All @@ -117,10 +172,15 @@ def get_registered_model(self, name: str) -> RegisteredModel:

def transition_model_version_stage(self, name: str, version: str) -> ModelVersion:
"""
Set the stage of a registered model.
More than one model can be in one stage.
We do not support the stage and archive_existing_versions paramters,
as we set to those.
Set the stage of a registered model. More than one model can be in one stage.
note: We do not support the stage and archive_existing_versions parameters.
Args:
name: Name of the model
version: Version of the model
Returns:
mlflow ModelVersion
"""
name = self.get_env_model_name(name)
return super().transition_model_version_stage(
Expand All @@ -138,7 +198,18 @@ def get_model_version(self, name: str, version: str) -> ModelVersion:
def load_model_version(
self, model_flavor, name: str, version: str, unwrap_model: bool = False
) -> Any:
"""Load a model version within the specified stage"""
"""
Load a model version within the specified stage
Args:
model_flavor: i.e. mlflow.pyfunc or mlflow.spark
name: Name of the model
version: Version of the model
unwrap_model: If True, return the underlying model implementation
Returns:
The loaded model
"""
model_version = self.get_model_version(
name=name,
version=version,
Expand All @@ -157,9 +228,11 @@ def load_latest_model(
Args:
model_flavor: i.e. mlflow.pyfunc or mlflow.spark
name: Name of the model
unwrap_model: If True, return the underlying model implementation
Returns:
The loaded model
"""
latest_versions = self.get_latest_versions(name)
model = model_flavor.load_model(latest_versions[0].source)
Expand All @@ -173,15 +246,15 @@ def log_model_helper(
"""
Standardize model logging setting an environment aware name
and stage attribute.
All parameters are passed to the log_model of the model_flavor.
Args:
model_flavor: i.e. mlflow.pyfunc
registered_model_name: base model name
kwargs: parameters for log_model minus registered_model_name, see documentation of model_flavor
kwargs: parameters for model_flavor.log_model, see documentation of model_flavor
i.e. mlflow.sklearn.log_model
Returns:
Tuple of a mlflow.entities.model_registry.ModelVersion object and mlflow.models.model.ModelInfo object.
"""
if "artifact_path" in kwargs:
kwargs["artifact_path"] = self.get_env_model_name(kwargs["artifact_path"])
Expand All @@ -199,11 +272,25 @@ def log_model_helper(
return model_version, model_info

def get_env_experiment_name(self, name: str) -> str:
"""Get environment specific experiment name."""
"""Get environment specific experiment name.
Args:
name: Name of the experiment
Returns:
experiment_name: Name/path of the experiment with environment prefix
"""
return f"/experiments/{self.env_name}/{name}"

def create_experiment_if_not_exists(self, name: str) -> str:
"""Create MLflow experiment if not exists."""
"""
Create MLflow experiment if not exists.
Args:
name: Name of the experiment
Returns:
experiment_id: ID of the experiment
"""
name = self.get_env_experiment_name(name)
try:
experiment_id = mlflow.create_experiment(name=name)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "environment-mlflow-client"
version = "1.1.0"
version = "1.1.1"
description = "environment scoped mlflow client"
authors = ["Team Sigma <[email protected]>"]
license = "LICENSE.txt"
Expand Down

0 comments on commit 8d21b77

Please sign in to comment.