From 27a02f202ef8504c1a85251fa0d54f62b288a0c4 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Wed, 12 Jan 2022 13:40:16 +0000 Subject: [PATCH 1/4] wandb integration --- config/sample_ddpm_128.json | 3 ++ config/sample_sr3_128.json | 3 ++ config/sr_ddpm_16_128.json | 3 ++ config/sr_sr3_16_128.json | 9 ++-- config/sr_sr3_64_512.json | 3 ++ core/logger.py | 8 ++++ core/wandb_logger.py | 88 +++++++++++++++++++++++++++++++++++++ sr.py | 59 +++++++++++++++++++++++-- 8 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 core/wandb_logger.py diff --git a/config/sample_ddpm_128.json b/config/sample_ddpm_128.json index 0d8230dee..12288e7ef 100755 --- a/config/sample_ddpm_128.json +++ b/config/sample_ddpm_128.json @@ -90,5 +90,8 @@ "update_ema_every": 1, "ema_decay": 0.9999 } + }, + "wandb": { + "project": "generation_ffhq_ddpm" } } \ No newline at end of file diff --git a/config/sample_sr3_128.json b/config/sample_sr3_128.json index 06b9076c8..20f2225d9 100755 --- a/config/sample_sr3_128.json +++ b/config/sample_sr3_128.json @@ -89,5 +89,8 @@ "update_ema_every": 1, "ema_decay": 0.9999 } + }, + "wandb": { + "project": "generation_ffhq_sr3" } } \ No newline at end of file diff --git a/config/sr_ddpm_16_128.json b/config/sr_ddpm_16_128.json index 1b8be7326..02c955e02 100755 --- a/config/sr_ddpm_16_128.json +++ b/config/sr_ddpm_16_128.json @@ -90,5 +90,8 @@ "update_ema_every": 1, "ema_decay": 0.9999 } + }, + "wandb": { + "project": "sr_ffhq" } } \ No newline at end of file diff --git a/config/sr_sr3_16_128.json b/config/sr_sr3_16_128.json index f8a4bc2c7..c092095c1 100755 --- a/config/sr_sr3_16_128.json +++ b/config/sr_sr3_16_128.json @@ -17,7 +17,7 @@ "name": "FFHQ", "mode": "HR", // whether need LR img "dataroot": "dataset/ffhq_16_128", - "datatype": "lmdb", //lmdb or img, path of img files + "datatype": "img", //lmdb or img, path of img files "l_resolution": 16, // low resolution need to super_resolution "r_resolution": 128, // high resolution "batch_size": 4, @@ -28,8 +28,8 @@ "val": { "name": "CelebaHQ", "mode": "LRHR", - "dataroot": "dataset/celebahq_16_128", - "datatype": "lmdb", //lmdb or img, path of img files + "dataroot": "dataset/ffhq_16_128", + "datatype": "img", //lmdb or img, path of img files "l_resolution": 16, "r_resolution": 128, "data_len": 50 // data length in validation @@ -89,5 +89,8 @@ "update_ema_every": 1, "ema_decay": 0.9999 } + }, + "wandb": { + "project": "sr_ffhq" } } \ No newline at end of file diff --git a/config/sr_sr3_64_512.json b/config/sr_sr3_64_512.json index c05bb65dc..853bcfcd9 100755 --- a/config/sr_sr3_64_512.json +++ b/config/sr_sr3_64_512.json @@ -92,5 +92,8 @@ "update_ema_every": 1, "ema_decay": 0.9999 } + }, + "wandb": { + "project": "distributed_high_sr_ffhq" } } \ No newline at end of file diff --git a/core/logger.py b/core/logger.py index fac3eb995..ea1159187 100644 --- a/core/logger.py +++ b/core/logger.py @@ -22,6 +22,9 @@ def parse(args): phase = args.phase opt_path = args.config gpu_ids = args.gpu_ids + enable_wandb = args.enable_wandb + log_wandb_ckpt = args.log_wandb_ckpt + log_eval = args.log_eval # remove comments starting with '//' json_str = '' with open(opt_path, 'r') as f: @@ -72,6 +75,11 @@ def parse(args): if phase == 'train': opt['datasets']['val']['data_len'] = 3 + # W&B Logging + opt['enable_wandb'] = enable_wandb + opt['log_wandb_ckpt'] = log_wandb_ckpt + opt['log_eval'] = log_eval + return opt diff --git a/core/wandb_logger.py b/core/wandb_logger.py new file mode 100644 index 000000000..71d7da283 --- /dev/null +++ b/core/wandb_logger.py @@ -0,0 +1,88 @@ +import os + +class WandbLogger: + """ + Log using `Weights and Biases`. + """ + def __init__(self, opt): + try: + import wandb + except ImportError: + raise ImportError( + "To use the Weights and Biases Logger please install wandb." + "Run `pip install wandb` to install it." + ) + + self._wandb = wandb + + # Initialize a W&B run + if self._wandb.run is None: + self._wandb.init( + project=opt['wandb']['project'], + config=opt, + dir='./experiments' + ) + + self.config = self._wandb.config + + if self.config['log_eval']: + self.eval_table = self._wandb.Table(columns=['fake_image', + 'sr_image', + 'hr_image', + 'psnr', + 'ssim']) + + def log_metrics(self, metrics, commit=True): + """ + Log train/validation metrics onto W&B. + + metrics: dictionary of metrics to be logged + """ + self._wandb.log(metrics, commit=commit) + + def log_image(self, key_name, image_array): + """ + Log image array onto W&B. + + key_name: name of the key + image_array: numpy array of image. + """ + self._wandb.log({key_name: self._wandb.Image(image_array)}) + + def log_checkpoint(self, current_epoch, current_step): + """ + Log the model checkpoint as W&B artifacts + + current_epoch: the current epoch + current_step: the current batch step + """ + model_artifact = self._wandb.Artifact( + self._wandb.run.id + "_model", type="model" + ) + + gen_path = os.path.join( + self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch)) + opt_path = os.path.join( + self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch)) + + model_artifact.add_file(gen_path) + model_artifact.add_file(opt_path) + self._wandb.log_artifact(model_artifact, aliases=["latest"]) + + def log_eval_data(self, fake_img, sr_img, hr_img, psnr, ssim): + """ + Add data row-wise to the initialized table. + """ + self.eval_table.add_data( + self._wandb.Image(fake_img), + self._wandb.Image(sr_img), + self._wandb.Image(hr_img), + psnr, + ssim + ) + + def log_eval_table(self, commit=False): + """ + Log the table + """ + self._wandb.log({'eval_data': self.eval_table}, commit=commit) diff --git a/sr.py b/sr.py index f78144350..c8d0f7b75 100755 --- a/sr.py +++ b/sr.py @@ -5,9 +5,11 @@ import logging import core.logger as Logger import core.metrics as Metrics +from core.wandb_logger import WandbLogger from tensorboardX import SummaryWriter import os import numpy as np +import wandb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -17,6 +19,9 @@ help='Run either train(training) or val(generation)', default='train') parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) parser.add_argument('-debug', '-d', action='store_true') + parser.add_argument('-enable_wandb', action='store_true') + parser.add_argument('-log_wandb_ckpt', action='store_true') + parser.add_argument('-log_eval', action='store_true') # parse configs args = parser.parse_args() @@ -35,6 +40,16 @@ logger.info(Logger.dict2str(opt)) tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) + # Initialize WandbLogger + if opt['enable_wandb']: + wandb_logger = WandbLogger(opt) + wandb.define_metric('validation/val_step') + wandb.define_metric('epoch') + wandb.define_metric("validation/*", step_metric="val_step") + val_step = 0 + else: + wandb_logger = None + # dataset for phase, dataset_opt in opt['datasets'].items(): if phase == 'train' and args.phase != 'val': @@ -81,6 +96,9 @@ tb_logger.add_scalar(k, v, current_step) logger.info(message) + if wandb_logger: + wandb_logger.log_metrics(logs) + # validation if current_step % opt['train']['val_freq'] == 0: avg_psnr = 0.0 @@ -118,6 +136,12 @@ avg_psnr += Metrics.calculate_psnr( sr_img, hr_img) + if wandb_logger: + wandb_logger.log_image( + f'validation_{idx}', + np.concatenate((fake_img, sr_img, hr_img), axis=1) + ) + avg_psnr = avg_psnr / idx diffusion.set_new_noise_schedule( opt['model']['beta_schedule']['train'], schedule_phase='train') @@ -129,9 +153,23 @@ # tensorboard logger tb_logger.add_scalar('psnr', avg_psnr, current_step) + if wandb_logger: + wandb_logger.log_metrics({ + 'validation/val_psnr': avg_psnr, + 'validation/val_step': val_step + }) + val_step += 1 + if current_step % opt['train']['save_checkpoint_freq'] == 0: logger.info('Saving models and training states.') diffusion.save_network(current_epoch, current_step) + + if wandb_logger and opt['log_wandb_ckpt']: + wandb_logger.log_checkpoint(current_epoch, current_step) + + if wandb_logger: + wandb_logger.log_metrics({'epoch': current_epoch-1}) + # save model logger.info('End of training.') else: @@ -175,10 +213,15 @@ fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx)) # generation - avg_psnr += Metrics.calculate_psnr( - Metrics.tensor2img(visuals['SR'][-1]), hr_img) - avg_ssim += Metrics.calculate_ssim( - Metrics.tensor2img(visuals['SR'][-1]), hr_img) + eval_psnr = Metrics.calculate_psnr(Metrics.tensor2img(visuals['SR'][-1]), hr_img) + eval_ssim = Metrics.calculate_ssim(Metrics.tensor2img(visuals['SR'][-1]), hr_img) + + avg_psnr += eval_psnr + avg_ssim += eval_ssim + + if wandb_logger and opt['log_eval']: + wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img, eval_psnr, eval_ssim) + avg_psnr = avg_psnr / idx avg_ssim = avg_ssim / idx @@ -188,3 +231,11 @@ logger_val = logging.getLogger('val') # validation logger logger_val.info(' psnr: {:.4e}, ssim:{:.4e}'.format( current_epoch, current_step, avg_psnr, avg_ssim)) + + if wandb_logger: + if opt['log_eval']: + wandb_logger.log_eval_table() + wandb_logger.log_metrics({ + 'PSNR': float(avg_psnr), + 'SSIM': float(avg_ssim) + }) From e62c21f0746e99ee91dd17df349dbef4f6ce8d83 Mon Sep 17 00:00:00 2001 From: ayulockin Date: Wed, 12 Jan 2022 14:25:16 +0000 Subject: [PATCH 2/4] wandb integrate sample.py --- config/sr_sr3_16_128.json | 6 +++--- sample.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/config/sr_sr3_16_128.json b/config/sr_sr3_16_128.json index c092095c1..6672510f9 100755 --- a/config/sr_sr3_16_128.json +++ b/config/sr_sr3_16_128.json @@ -17,7 +17,7 @@ "name": "FFHQ", "mode": "HR", // whether need LR img "dataroot": "dataset/ffhq_16_128", - "datatype": "img", //lmdb or img, path of img files + "datatype": "lmdb", //lmdb or img, path of img files "l_resolution": 16, // low resolution need to super_resolution "r_resolution": 128, // high resolution "batch_size": 4, @@ -28,8 +28,8 @@ "val": { "name": "CelebaHQ", "mode": "LRHR", - "dataroot": "dataset/ffhq_16_128", - "datatype": "img", //lmdb or img, path of img files + "dataroot": "dataset/celebahq_16_128", + "datatype": "lmdb", //lmdb or img, path of img files "l_resolution": 16, "r_resolution": 128, "data_len": 50 // data length in validation diff --git a/sample.py b/sample.py index b63da3e6d..56f820040 100755 --- a/sample.py +++ b/sample.py @@ -5,9 +5,11 @@ import logging import core.logger as Logger import core.metrics as Metrics +from core.wandb_logger import WandbLogger from tensorboardX import SummaryWriter import os import numpy as np +import wandb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -17,6 +19,9 @@ help='Run either train(training) or val(generation)', default='train') parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) parser.add_argument('-debug', '-d', action='store_true') + parser.add_argument('-enable_wandb', action='store_true') + parser.add_argument('-log_wandb_ckpt', action='store_true') + parser.add_argument('-log_eval', action='store_true') # parse configs args = parser.parse_args() @@ -35,6 +40,16 @@ logger.info(Logger.dict2str(opt)) tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) + # Initialize WandbLogger + if opt['enable_wandb']: + wandb_logger = WandbLogger(opt) + # wandb.define_metric('validation/val_step') + # wandb.define_metric('epoch') + # wandb.define_metric("validation/*", step_metric="val_step") + val_step = 0 + else: + wandb_logger = None + # dataset for phase, dataset_opt in opt['datasets'].items(): if phase == 'train' and args.phase != 'val': @@ -78,6 +93,9 @@ tb_logger.add_scalar(k, v, current_step) logger.info(message) + if wandb_logger: + wandb_logger.log_metrics(logs) + # validation if current_step % opt['train']['val_freq'] == 0: result_path = '{}/{}'.format(opt['path'] @@ -100,12 +118,20 @@ 'Iter_{}'.format(current_step), np.transpose(sample_img, [2, 0, 1]), idx) + + if wandb_logger: + wandb_logger.log_image(f'validation_{idx}', sample_img) + diffusion.set_new_noise_schedule( opt['model']['beta_schedule']['train'], schedule_phase='train') if current_step % opt['train']['save_checkpoint_freq'] == 0: logger.info('Saving models and training states.') diffusion.save_network(current_epoch, current_step) + + if wandb_logger and opt['log_wandb_ckpt']: + wandb_logger.log_checkpoint(current_epoch, current_step) + # save model logger.info('End of training.') else: From ac2b2226edcb0cda2617af184f0584f948c1e54a Mon Sep 17 00:00:00 2001 From: ayulockin Date: Wed, 12 Jan 2022 17:13:26 +0000 Subject: [PATCH 3/4] wandb integrate infer --- core/logger.py | 21 ++++++++++++++----- core/wandb_logger.py | 48 +++++++++++++++++++++++++++++++++++--------- infer.py | 16 +++++++++++++++ sample.py | 7 ++++++- 4 files changed, 76 insertions(+), 16 deletions(-) diff --git a/core/logger.py b/core/logger.py index ea1159187..e0634ca7a 100644 --- a/core/logger.py +++ b/core/logger.py @@ -23,8 +23,6 @@ def parse(args): opt_path = args.config gpu_ids = args.gpu_ids enable_wandb = args.enable_wandb - log_wandb_ckpt = args.log_wandb_ckpt - log_eval = args.log_eval # remove comments starting with '//' json_str = '' with open(opt_path, 'r') as f: @@ -76,10 +74,23 @@ def parse(args): opt['datasets']['val']['data_len'] = 3 # W&B Logging + try: + log_wandb_ckpt = args.log_wandb_ckpt + opt['log_wandb_ckpt'] = log_wandb_ckpt + except: + pass + try: + log_eval = args.log_eval + opt['log_eval'] = log_eval + except: + pass + try: + log_infer = args.log_infer + opt['log_infer'] = log_infer + except: + pass opt['enable_wandb'] = enable_wandb - opt['log_wandb_ckpt'] = log_wandb_ckpt - opt['log_eval'] = log_eval - + return opt diff --git a/core/wandb_logger.py b/core/wandb_logger.py index 71d7da283..20f8d4f0d 100644 --- a/core/wandb_logger.py +++ b/core/wandb_logger.py @@ -25,12 +25,21 @@ def __init__(self, opt): self.config = self._wandb.config - if self.config['log_eval']: + if self.config.get('log_eval', None): self.eval_table = self._wandb.Table(columns=['fake_image', 'sr_image', 'hr_image', 'psnr', 'ssim']) + else: + self.eval_table = None + + if self.config.get('log_infer', None): + self.infer_table = self._wandb.Table(columns=['fake_image', + 'sr_image', + 'hr_image']) + else: + self.infer_table = None def log_metrics(self, metrics, commit=True): """ @@ -49,6 +58,15 @@ def log_image(self, key_name, image_array): """ self._wandb.log({key_name: self._wandb.Image(image_array)}) + def log_images(self, key_name, list_images): + """ + Log list of image array onto W&B + + key_name: name of the key + list_images: list of numpy image arrays + """ + self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]}) + def log_checkpoint(self, current_epoch, current_step): """ Log the model checkpoint as W&B artifacts @@ -69,20 +87,30 @@ def log_checkpoint(self, current_epoch, current_step): model_artifact.add_file(opt_path) self._wandb.log_artifact(model_artifact, aliases=["latest"]) - def log_eval_data(self, fake_img, sr_img, hr_img, psnr, ssim): + def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None): """ Add data row-wise to the initialized table. """ - self.eval_table.add_data( - self._wandb.Image(fake_img), - self._wandb.Image(sr_img), - self._wandb.Image(hr_img), - psnr, - ssim - ) + if psnr is not None and ssim is not None: + self.eval_table.add_data( + self._wandb.Image(fake_img), + self._wandb.Image(sr_img), + self._wandb.Image(hr_img), + psnr, + ssim + ) + else: + self.infer_table.add_data( + self._wandb.Image(fake_img), + self._wandb.Image(sr_img), + self._wandb.Image(hr_img) + ) def log_eval_table(self, commit=False): """ Log the table """ - self._wandb.log({'eval_data': self.eval_table}, commit=commit) + if self.eval_table: + self._wandb.log({'eval_data': self.eval_table}, commit=commit) + elif self.infer_table: + self._wandb.log({'infer_data': self.infer_table}, commit=commit) diff --git a/infer.py b/infer.py index ff6151dde..a8201fea0 100755 --- a/infer.py +++ b/infer.py @@ -5,9 +5,11 @@ import logging import core.logger as Logger import core.metrics as Metrics +from core.wandb_logger import WandbLogger from tensorboardX import SummaryWriter import os import numpy as np +import wandb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -16,6 +18,8 @@ parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val') parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) parser.add_argument('-debug', '-d', action='store_true') + parser.add_argument('-enable_wandb', action='store_true') + parser.add_argument('-log_infer', action='store_true') # parse configs args = parser.parse_args() @@ -34,6 +38,12 @@ logger.info(Logger.dict2str(opt)) tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) + # Initialize WandbLogger + if opt['enable_wandb']: + wandb_logger = WandbLogger(opt) + else: + wandb_logger = None + # dataset for phase, dataset_opt in opt['datasets'].items(): if phase == 'val': @@ -85,3 +95,9 @@ hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx)) Metrics.save_img( fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx)) + + if wandb_logger and opt['log_infer']: + wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img) + + if wandb_logger and opt['log_infer']: + wandb_logger.log_eval_table(commit=True) diff --git a/sample.py b/sample.py index 56f820040..bca845e34 100755 --- a/sample.py +++ b/sample.py @@ -21,7 +21,6 @@ parser.add_argument('-debug', '-d', action='store_true') parser.add_argument('-enable_wandb', action='store_true') parser.add_argument('-log_wandb_ckpt', action='store_true') - parser.add_argument('-log_eval', action='store_true') # parse configs args = parser.parse_args() @@ -139,6 +138,7 @@ result_path = '{}'.format(opt['path']['results']) os.makedirs(result_path, exist_ok=True) + sample_imgs = [] for idx in range(sample_sum): idx += 1 diffusion.sample(continous=True) @@ -159,3 +159,8 @@ sample_img, '{}/{}_{}_sample_process.png'.format(result_path, current_step, idx)) Metrics.save_img( Metrics.tensor2img(visuals['SAM'][-1]), '{}/{}_{}_sample.png'.format(result_path, current_step, idx)) + + sample_imgs.append(Metrics.tensor2img(visuals['SAM'][-1])) + + if wandb_logger: + wandb_logger.log_images('eval_images', sample_imgs) From 18e12f20f927aabe4bdbfae7949762577e1350bd Mon Sep 17 00:00:00 2001 From: ayulockin Date: Wed, 12 Jan 2022 19:44:37 +0000 Subject: [PATCH 4/4] minor --- sample.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sample.py b/sample.py index bca845e34..c3fc2236d 100755 --- a/sample.py +++ b/sample.py @@ -42,9 +42,6 @@ # Initialize WandbLogger if opt['enable_wandb']: wandb_logger = WandbLogger(opt) - # wandb.define_metric('validation/val_step') - # wandb.define_metric('epoch') - # wandb.define_metric("validation/*", step_metric="val_step") val_step = 0 else: wandb_logger = None