diff --git a/utils/loggers/wandb/wandb_utils.py b/utils/loggers/wandb/wandb_utils.py index f520fbba8850..9a80dc42ca95 100644 --- a/utils/loggers/wandb/wandb_utils.py +++ b/utils/loggers/wandb/wandb_utils.py @@ -5,8 +5,8 @@ import sys from contextlib import contextmanager from pathlib import Path -import pkg_resources as pkg +import pkg_resources as pkg import yaml from tqdm import tqdm @@ -49,9 +49,11 @@ def check_wandb_dataset(data_file): if check_file(data_file) and data_file.endswith('.yaml'): with open(data_file, errors='ignore') as f: data_dict = yaml.safe_load(f) - is_wandb_artifact = (data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX) or - data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX)) - if is_wandb_artifact: + is_trainset_wandb_artifact = (isinstance(data_dict['train'], str) and + data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX)) + is_valset_wandb_artifact = (isinstance(data_dict['val'], str) and + data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX)) + if is_trainset_wandb_artifact or is_valset_wandb_artifact: return data_dict else: return check_dataset(data_file)