From 02719dde52a99f28603be383334028a7ab9f1e06 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 13:48:14 +0200 Subject: [PATCH] Update `feature_visualization()` (#3807) * Update `feature_visualization()` Only plot for data with height, width > 1 * cleanup * Cleanup --- utils/plots.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index 36386371dbec..4b6c63992ac7 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -448,26 +448,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): fig.savefig(Path(save_dir) / 'results.png', dpi=200) -def feature_visualization(features, module_type, module_idx, n=64): +def feature_visualization(x, module_type, stage, n=64): """ - features: Features to be visualized + x: Features to be visualized module_type: Module type - module_idx: Module layer index within model + stage: Module stage within model n: Maximum number of feature maps to plot """ - project, name = 'runs/features', 'exp' - save_dir = increment_path(Path(project) / name) # increment run - save_dir.mkdir(parents=True, exist_ok=True) # make dir - - plt.figure(tight_layout=True) - blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension - n = min(n, len(blocks)) - for i in range(n): - feature = transforms.ToPILImage()(blocks[i].squeeze()) - ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) - ax.axis('off') - plt.imshow(feature) # cmap='gray' - - f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png" - print(f'Saving {save_dir / f}...') - plt.savefig(save_dir / f, dpi=300) + batch, channels, height, width = x.shape # batch, channels, height, width + if height > 1 and width > 1: + project, name = 'runs/features', 'exp' + save_dir = increment_path(Path(project) / name) # increment run + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + plt.figure(tight_layout=True) + blocks = torch.chunk(x, channels, dim=1) # block by channel dimension + n = min(n, len(blocks)) + for i in range(n): + feature = transforms.ToPILImage()(blocks[i].squeeze()) + ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) + ax.axis('off') + plt.imshow(feature) # cmap='gray' + + f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png" + print(f'Saving {save_dir / f}...') + plt.savefig(save_dir / f, dpi=300)