diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e4e40f..3dd9d43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog ## Unreleased +- Add "--exclude-artifacts" option to sync command (mihran113) - Set mlflow experiment name as `aim.Run`'s experiment and parse the mlflow run name (mihran113) - Try to parse "string-ified" params values (sirykd) - Remove 'mlflow_run_name' and add 'mlflow_experiment_id' fields (sirykd) diff --git a/aimlflow/cli.py b/aimlflow/cli.py index 413fb8f..4a612b2 100644 --- a/aimlflow/cli.py +++ b/aimlflow/cli.py @@ -25,7 +25,9 @@ def cli_entry_point(): writable=True)) @click.option('--mlflow-tracking-uri', required=False, default=None) @click.option('--experiment', '-e', required=False, default=None) -def sync(aim_repo, mlflow_tracking_uri, experiment): +@click.option('--exclude-artifacts', multiple=True, required=False) +def sync(aim_repo, mlflow_tracking_uri, experiment, exclude_artifacts): + repo_path = clean_repo_path(aim_repo) or Repo.default_repo_path() repo_inst = Repo.from_path(repo_path) @@ -33,10 +35,10 @@ def sync(aim_repo, mlflow_tracking_uri, experiment): if not mlflow_tracking_uri: raise ClickException('MLFlow tracking URI must be provided either through ENV or CLI.') - watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment) + watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment, exclude_artifacts) click.echo('Converting existing MLflow logs.') - convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment) + convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment, exclude_artifacts) click.echo(f'Starting watcher on {mlflow_tracking_uri}.') watcher.start() diff --git a/aimlflow/utils.py b/aimlflow/utils.py index 1cef538..280c35b 100644 --- a/aimlflow/utils.py +++ b/aimlflow/utils.py @@ -1,3 +1,4 @@ +import fnmatch import click import collections import mlflow @@ -117,7 +118,10 @@ def collect_run_params(aim_run, mlflow_run): } -def collect_artifacts(aim_run, mlflow_run, mlflow_client): +def collect_artifacts(aim_run, mlflow_run, mlflow_client, exclude_artifacts): + if '*' in exclude_artifacts: + return + run_id = mlflow_run.info.run_id artifacts_cache_key = '_mlflow_artifacts_cache' @@ -139,6 +143,16 @@ def collect_artifacts(aim_run, mlflow_run, mlflow_client): continue else: artifacts_cache.append(file_info.path) + + if exclude_artifacts: + exclude = False + for expr in exclude_artifacts: + if fnmatch.fnmatch(file_info.path, expr): + exclude = True + break + if exclude: + continue + downloaded_path = mlflow_client.download_artifacts(run_id, file_info.path, dst_path=temp_path) if file_info.path.endswith(HTML_EXTENSIONS): if not __html_warning_issued: @@ -199,7 +213,7 @@ def collect_metrics(aim_run, mlflow_run, mlflow_client, timestamp=None): aim_run.track(m.value, step=m.step, name=m.key) -def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=False): +def convert_existing_logs(repo_inst, tracking_uri, experiment=None, excluded_artifacts=None, no_cache=False): client = mlflow.tracking.client.MlflowClient(tracking_uri=tracking_uri) experiments = get_mlflow_experiments(client, experiment) @@ -217,7 +231,7 @@ def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=Fal collect_metrics(aim_run, run, client) # Collect artifacts - collect_artifacts(aim_run, run, client) + collect_artifacts(aim_run, run, client, excluded_artifacts) run_cache.refresh() diff --git a/aimlflow/watcher.py b/aimlflow/watcher.py index 82eb071..53e8147 100644 --- a/aimlflow/watcher.py +++ b/aimlflow/watcher.py @@ -28,6 +28,7 @@ def __init__(self, repo: 'Repo', tracking_uri: str, experiment: str = None, + exclude_artifacts: str = None, interval: Union[int, float] = WATCH_INTERVAL_DEFAULT, ): @@ -37,6 +38,8 @@ def __init__(self, self._watch_interval = interval self._client = MlflowClient(tracking_uri) + + self._exclude_artifacts = exclude_artifacts self._experiment = experiment self._experiments = get_mlflow_experiments(self._client, self._experiment) self._repo = repo @@ -81,7 +84,7 @@ def _process_single_run(self, aim_run, mlflow_run): collect_metrics(aim_run, mlflow_run, self._client, timestamp=self._last_watch_time) # Collect artifacts - collect_artifacts(aim_run, mlflow_run, self._client) + collect_artifacts(aim_run, mlflow_run, self._client, self._exclude_artifacts) def _process_runs(self): watch_started_time = time.time()