diff --git a/hexrdgui/blit_manager.py b/hexrdgui/blit_manager.py new file mode 100644 index 000000000..0add673c2 --- /dev/null +++ b/hexrdgui/blit_manager.py @@ -0,0 +1,103 @@ +from collections.abc import Sequence, ValuesView + +from matplotlib.artist import Artist + + +class BlitManager: + def __init__(self, canvas): + """ + Parameters + ---------- + canvas : FigureCanvasAgg + The canvas to work with. The background will be cached when needed. + + This class was modified from here: + https://matplotlib.org/stable/users/explain/animations/blitting.html + """ + self.canvas = canvas + self.bg = None + + # This dict can contain nested dicts, lists, etc. + # But all non-container values must be artists. + # We will find them recursively. + self.artists = {} + + # grab the background on every draw + self.cid = canvas.mpl_connect("draw_event", self.on_draw) + + def disconnect(self): + self.remove_all_artists() + + if self.cid is not None: + self.mpl_disconnect(self.cid) + self.cid = None + + def on_draw(self, event): + """Callback to register with 'draw_event'.""" + cv = self.canvas + if event is not None: + if event.canvas != cv: + msg = ( + f'Event canvas "{event.canvas}" does not match the ' + f'manager canvas "{cv}"' + ) + raise RuntimeError(msg) + + self.bg = cv.copy_from_bbox(cv.figure.bbox) + self.draw_all_artists() + + def remove_artists(self, *path): + # The *path is an arbitrary path into the artist dict + parent = None + d = self.artists + for key in path: + if key not in d: + # It already doesn't exist. Just return. + return + + parent = d + d = d[key] + + for artist in _recursive_yield_artists(d): + artist.remove() + + if parent: + del parent[key] + else: + self.artists.clear() + + def draw_all_artists(self): + """Draw all of the animated artists.""" + fig = self.canvas.figure + for artist in _recursive_yield_artists(self.artists): + fig.draw_artist(artist) + + def update(self): + """Update the screen with animated artists.""" + cv = self.canvas + fig = cv.figure + + # paranoia in case we missed the draw event, + if self.bg is None: + self.on_draw(None) + else: + # restore the background + cv.restore_region(self.bg) + # draw all of the animated artists + self.draw_all_artists() + # update the GUI state + cv.blit(fig.bbox) + + # let the GUI event loop process anything it has to do + cv.flush_events() + + +def _recursive_yield_artists(artists): + if isinstance(artists, dict): + yield from _recursive_yield_artists(artists.values()) + elif isinstance(artists, (Sequence, ValuesView)): + for v in artists: + if isinstance(v, Artist): + yield v + else: + yield from _recursive_yield_artists(v) diff --git a/hexrdgui/image_canvas.py b/hexrdgui/image_canvas.py index 9e852d038..21a30347e 100644 --- a/hexrdgui/image_canvas.py +++ b/hexrdgui/image_canvas.py @@ -16,6 +16,7 @@ import numpy as np from hexrdgui.async_worker import AsyncWorker +from hexrdgui.blit_manager import BlitManager from hexrdgui.calibration.cartesian_plot import cartesian_viewer from hexrdgui.calibration.polar_plot import polar_viewer from hexrdgui.calibration.raw_iviewer import raw_iviewer @@ -42,7 +43,6 @@ def __init__(self, parent=None, image_names=None): self.raw_axes = {} # only used for raw currently self.axes_images = [] - self.overlay_artists = {} self.cached_detector_borders = [] self.saturation_texts = [] self.cmap = HexrdConfig().default_cmap @@ -57,6 +57,7 @@ def __init__(self, parent=None, image_names=None): self._last_stereo_size = None self.stereo_border_artists = [] self.azimuthal_overlay_artists = [] + self.blit_manager = BlitManager(self) # Track the current mode so that we can more lazily clear on change. self.mode = None @@ -225,24 +226,20 @@ def scaled_image_dict(self): def scaled_images(self): return [self.transform(x) for x in self.unscaled_images] - def remove_all_overlay_artists(self): - while self.overlay_artists: - key = next(iter(self.overlay_artists)) - self.remove_overlay_artists(key) + @property + def blit_artists(self): + return self.blit_manager.artists - def remove_overlay_artists(self, key): - if key not in self.overlay_artists: - return + @property + def overlay_artists(self): + return self.blit_artists.setdefault('overlays', {}) - for det_key, artist_dict in self.overlay_artists[key].items(): - for artist_name, artist in artist_dict.items(): - if isinstance(artist, list): - while artist: - artist.pop(0).remove() - else: - artist.remove() + def remove_all_overlay_artists(self): + self.blit_manager.remove_artists('overlays') + self.blit_manager.artists['overlays'] = {} - del self.overlay_artists[key] + def remove_overlay_artists(self, key): + self.blit_manager.remove_artists('overlays', key) def prune_overlay_artists(self): # Remove overlay artists that no longer have an overlay associated @@ -380,7 +377,8 @@ def split(data): def plot(data, key, kwargs): # This logic was repeated if len(data) != 0: - artists[key], = axis.plot(*np.vstack(data).T, **kwargs) + artists[key], = axis.plot(*np.vstack(data).T, animated=True, + **kwargs) plot(rings, 'rings', data_style) plot(h_rings, 'h_rings', highlight_style['data']) @@ -404,7 +402,8 @@ def az_plot(data, key, kwargs): x = np.repeat(xmeans, 3) y = np.tile([0, 1, np.nan], len(xmeans)) - artists[key], = az_axis.plot(x, y, transform=trans, **kwargs) + artists[key], = az_axis.plot(x, y, transform=trans, + animated=True, **kwargs) az_plot(rings, 'az_rings', data_style) # NOTE: we still use the data_style for az_axis highlighted rings @@ -452,12 +451,14 @@ def split(data): def scatter(data, key, kwargs): # This logic was repeated if len(data) != 0: - artists[key] = axis.scatter(*np.asarray(data).T, **kwargs) + artists[key] = axis.scatter(*np.asarray(data).T, animated=True, + **kwargs) def plot(data, key, kwargs): # This logic was repeated if len(data) != 0: - artists[key], = axis.plot(*np.vstack(data).T, **kwargs) + artists[key], = axis.plot(*np.vstack(data).T, animated=True, + **kwargs) # Draw spots and highlighted spots scatter(spots, 'spots', data_style) @@ -530,11 +531,12 @@ def draw_rotation_series_overlay(self, artist_key, det_key, axis, data, overlay_artists = self.overlay_artists.setdefault(artist_key, {}) artists = overlay_artists.setdefault(det_key, {}) - artists['data'] = axis.scatter(*sliced_data.T, **data_style) + artists['data'] = axis.scatter(*sliced_data.T, animated=True, + **data_style) sliced_ranges = np.asarray(ranges)[slicer] artists['ranges'], = axis.plot(*np.vstack(sliced_ranges).T, - **ranges_style) + animated=True, **ranges_style) def redraw_overlay(self, overlay): # Remove the artists for this overlay @@ -573,7 +575,7 @@ def update_overlays(self): for overlay in HexrdConfig().overlays: self.draw_overlay(overlay) - self.draw_idle() + self.blit_manager.update() def clear_detector_borders(self): while self.cached_detector_borders: @@ -825,9 +827,9 @@ def finish_show_cartesian(self, iviewer): self.figure.tight_layout() self.update_auto_picked_data() - self.update_overlays() self.draw_detector_borders() self.update_beam_marker() + self.update_overlays() HexrdConfig().image_view_loaded.emit({'img': img}) @@ -952,9 +954,9 @@ def finish_show_polar(self, iviewer): self.figure.tight_layout() self.update_auto_picked_data() - self.update_overlays() self.draw_detector_borders() self.update_beam_marker() + self.update_overlays() HexrdConfig().image_view_loaded.emit({'img': img}) @@ -1018,9 +1020,9 @@ def finish_show_stereo(self, iviewer): self.draw_stereo_border() self.update_auto_picked_data() - self.update_overlays() self.draw_detector_borders() self.update_beam_marker() + self.update_overlays() HexrdConfig().image_view_loaded.emit({'img': img})