Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add --exclude-artifacts option to exclude artifacts based on glob expression #12

Merged
Merged 5 commits on Feb 28, 2023 (the source and target branch names were not captured in this page extract).
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

## Unreleased
- Add "--exclude-artifacts" option to sync command (mihran113)
- Set mlflow experiment name as `aim.Run`'s experiment and parse the mlflow run name (mihran113)
- Try to parse "string-ified" params values (sirykd)
- Remove 'mlflow_run_name' and add 'mlflow_experiment_id' fields (sirykd)
Expand Down
8 changes: 5 additions & 3 deletions aimlflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@ def cli_entry_point():
writable=True))
@click.option('--mlflow-tracking-uri', required=False, default=None)
@click.option('--experiment', '-e', required=False, default=None)
def sync(aim_repo, mlflow_tracking_uri, experiment):
@click.option('--exclude-artifacts', multiple=True, required=False)
def sync(aim_repo, mlflow_tracking_uri, experiment, exclude_artifacts):

repo_path = clean_repo_path(aim_repo) or Repo.default_repo_path()
repo_inst = Repo.from_path(repo_path)

mlflow_tracking_uri = mlflow_tracking_uri or os.environ.get('MLFLOW_TRACKING_URI')
if not mlflow_tracking_uri:
raise ClickException('MLFlow tracking URI must be provided either through ENV or CLI.')

watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment)
watcher = MLFlowWatcher(repo_inst, mlflow_tracking_uri, experiment, exclude_artifacts)

click.echo('Converting existing MLflow logs.')
convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment)
convert_existing_logs(repo_inst, mlflow_tracking_uri, experiment, exclude_artifacts)

click.echo(f'Starting watcher on {mlflow_tracking_uri}.')
watcher.start()
Expand Down
20 changes: 17 additions & 3 deletions aimlflow/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import fnmatch
import click
import collections
import mlflow
Expand Down Expand Up @@ -117,7 +118,10 @@ def collect_run_params(aim_run, mlflow_run):
}


def collect_artifacts(aim_run, mlflow_run, mlflow_client):
def collect_artifacts(aim_run, mlflow_run, mlflow_client, exclude_artifacts):
if '*' in exclude_artifacts:
return

run_id = mlflow_run.info.run_id

artifacts_cache_key = '_mlflow_artifacts_cache'
Expand All @@ -139,6 +143,16 @@ def collect_artifacts(aim_run, mlflow_run, mlflow_client):
continue
else:
artifacts_cache.append(file_info.path)

if exclude_artifacts:
exclude = False
for expr in exclude_artifacts:
if fnmatch.fnmatch(file_info.path, expr):
exclude = True
break
if exclude:
continue

downloaded_path = mlflow_client.download_artifacts(run_id, file_info.path, dst_path=temp_path)
if file_info.path.endswith(HTML_EXTENSIONS):
if not __html_warning_issued:
Expand Down Expand Up @@ -199,7 +213,7 @@ def collect_metrics(aim_run, mlflow_run, mlflow_client, timestamp=None):
aim_run.track(m.value, step=m.step, name=m.key)


def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=False):
def convert_existing_logs(repo_inst, tracking_uri, experiment=None, excluded_artifacts=None, no_cache=False):
client = mlflow.tracking.client.MlflowClient(tracking_uri=tracking_uri)

experiments = get_mlflow_experiments(client, experiment)
Expand All @@ -217,7 +231,7 @@ def convert_existing_logs(repo_inst, tracking_uri, experiment=None, no_cache=Fal
collect_metrics(aim_run, run, client)

# Collect artifacts
collect_artifacts(aim_run, run, client)
collect_artifacts(aim_run, run, client, excluded_artifacts)

run_cache.refresh()

Expand Down
5 changes: 4 additions & 1 deletion aimlflow/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self,
repo: 'Repo',
tracking_uri: str,
experiment: str = None,
exclude_artifacts: str = None,
interval: Union[int, float] = WATCH_INTERVAL_DEFAULT,
):

Expand All @@ -37,6 +38,8 @@ def __init__(self,
self._watch_interval = interval

self._client = MlflowClient(tracking_uri)

self._exclude_artifacts = exclude_artifacts
self._experiment = experiment
self._experiments = get_mlflow_experiments(self._client, self._experiment)
self._repo = repo
Expand Down Expand Up @@ -81,7 +84,7 @@ def _process_single_run(self, aim_run, mlflow_run):
collect_metrics(aim_run, mlflow_run, self._client, timestamp=self._last_watch_time)

# Collect artifacts
collect_artifacts(aim_run, mlflow_run, self._client)
collect_artifacts(aim_run, mlflow_run, self._client, self._exclude_artifacts)

def _process_runs(self):
watch_started_time = time.time()
Expand Down