Skip to content

Commit

Permalink
fixing mlflow logging to Databricks workspace file paths with /Shared…
Browse files Browse the repository at this point in the history
…/ prefix (#3410)

* fixing os file path with /Shared/ prefix

* lstrip '/' from experiment name if not '/Shared/' or '/Users/'

Co-authored-by: Mihir Patel <[email protected]>

* doesnt modify experiment name if it has '/Shared/' as a prefix

* fix formatting

* lint

---------

Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
JackZ-db and mvpatel2000 authored Jun 18, 2024
1 parent 2587c05 commit 152e528
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,12 @@ def __init__(
)
assert self.experiment_name is not None # type hint

if os.getenv('DATABRICKS_TOKEN') is not None and not self.experiment_name.startswith('/Users/'):
if os.getenv(
'DATABRICKS_TOKEN',
) is not None and not self.experiment_name.startswith((
'/Users/',
'/Shared/',
)):
try:
from databricks.sdk import WorkspaceClient
except ImportError as e:
Expand All @@ -160,7 +165,7 @@ def __init__(
conda_channel='conda-forge',
) from e
databricks_username = WorkspaceClient().current_user.me().user_name or ''
self.experiment_name = '/' + os.path.join('Users', databricks_username, self.experiment_name)
self.experiment_name = os.path.join('/Users', databricks_username, self.experiment_name.strip('/'))

self._mlflow_client = MlflowClient(self.tracking_uri)
# Set experiment
Expand Down

0 comments on commit 152e528

Please sign in to comment.