Skip to content

Commit

Permalink
Add blit manager to manage overlay artists
Browse files Browse the repository at this point in the history
The blit manager was modified from this example: https://matplotlib.org/stable/users/explain/animations/blitting.html

It is being used to automatically manage things like updating the background
image, updating the blit artists, etc.

We are now using it to manage the overlay artists, including drawing them
and removing them.

Blitting the overlay artists provides a *significant* performance boost when
we modify parameters that only require the overlay artists to be redrawn.
Especially when we have a high resolution in any of the views (other than raw,
which doesn't have a resolution).

For example, the sliders for the Laue parameters are significantly more
interactive now. And this provides a great opportunity for us to boost
performance for the pressure sliders too.

Signed-off-by: Patrick Avery <[email protected]>
  • Loading branch information
psavery committed Dec 1, 2023
1 parent 2ea73fe commit 211ece7
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 26 deletions.
103 changes: 103 additions & 0 deletions hexrdgui/blit_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from collections.abc import Sequence, ValuesView

from matplotlib.artist import Artist


class BlitManager:
def __init__(self, canvas):
"""
Parameters
----------
canvas : FigureCanvasAgg
The canvas to work with. The background will be cached when needed.
This class was modified from here:
https://matplotlib.org/stable/users/explain/animations/blitting.html
"""
self.canvas = canvas
self.bg = None

# This dict can contain nested dicts, lists, etc.
# But all non-container values must be artists.
# We will find them recursively.
self.artists = {}

# grab the background on every draw
self.cid = canvas.mpl_connect("draw_event", self.on_draw)

def disconnect(self):
self.remove_all_artists()

if self.cid is not None:
self.mpl_disconnect(self.cid)
self.cid = None

def on_draw(self, event):
"""Callback to register with 'draw_event'."""
cv = self.canvas
if event is not None:
if event.canvas != cv:
msg = (
f'Event canvas "{event.canvas}" does not match the '
f'manager canvas "{cv}"'
)
raise RuntimeError(msg)

self.bg = cv.copy_from_bbox(cv.figure.bbox)
self.draw_all_artists()

def remove_artists(self, *path):
# The *path is an arbitrary path into the artist dict
parent = None
d = self.artists
for key in path:
if key not in d:
# It already doesn't exist. Just return.
return

parent = d
d = d[key]

for artist in _recursive_yield_artists(d):
artist.remove()

if parent:
del parent[key]
else:
self.artists.clear()

def draw_all_artists(self):
"""Draw all of the animated artists."""
fig = self.canvas.figure
for artist in _recursive_yield_artists(self.artists):
fig.draw_artist(artist)

def update(self):
"""Update the screen with animated artists."""
cv = self.canvas
fig = cv.figure

# paranoia in case we missed the draw event,
if self.bg is None:
self.on_draw(None)
else:
# restore the background
cv.restore_region(self.bg)
# draw all of the animated artists
self.draw_all_artists()
# update the GUI state
cv.blit(fig.bbox)

# let the GUI event loop process anything it has to do
cv.flush_events()


def _recursive_yield_artists(artists):
if isinstance(artists, dict):
yield from _recursive_yield_artists(artists.values())
elif isinstance(artists, (Sequence, ValuesView)):
for v in artists:
if isinstance(v, Artist):
yield v
else:
yield from _recursive_yield_artists(v)
54 changes: 28 additions & 26 deletions hexrdgui/image_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np

from hexrdgui.async_worker import AsyncWorker
from hexrdgui.blit_manager import BlitManager
from hexrdgui.calibration.cartesian_plot import cartesian_viewer
from hexrdgui.calibration.polar_plot import polar_viewer
from hexrdgui.calibration.raw_iviewer import raw_iviewer
Expand All @@ -42,7 +43,6 @@ def __init__(self, parent=None, image_names=None):

self.raw_axes = {} # only used for raw currently
self.axes_images = []
self.overlay_artists = {}
self.cached_detector_borders = []
self.saturation_texts = []
self.cmap = HexrdConfig().default_cmap
Expand All @@ -57,6 +57,7 @@ def __init__(self, parent=None, image_names=None):
self._last_stereo_size = None
self.stereo_border_artists = []
self.azimuthal_overlay_artists = []
self.blit_manager = BlitManager(self)

# Track the current mode so that we can more lazily clear on change.
self.mode = None
Expand Down Expand Up @@ -225,24 +226,20 @@ def scaled_image_dict(self):
def scaled_images(self):
return [self.transform(x) for x in self.unscaled_images]

def remove_all_overlay_artists(self):
while self.overlay_artists:
key = next(iter(self.overlay_artists))
self.remove_overlay_artists(key)
@property
def blit_artists(self):
return self.blit_manager.artists

def remove_overlay_artists(self, key):
if key not in self.overlay_artists:
return
@property
def overlay_artists(self):
return self.blit_artists.setdefault('overlays', {})

for det_key, artist_dict in self.overlay_artists[key].items():
for artist_name, artist in artist_dict.items():
if isinstance(artist, list):
while artist:
artist.pop(0).remove()
else:
artist.remove()
def remove_all_overlay_artists(self):
self.blit_manager.remove_artists('overlays')
self.blit_manager.artists['overlays'] = {}

del self.overlay_artists[key]
def remove_overlay_artists(self, key):
self.blit_manager.remove_artists('overlays', key)

def prune_overlay_artists(self):
# Remove overlay artists that no longer have an overlay associated
Expand Down Expand Up @@ -380,7 +377,8 @@ def split(data):
def plot(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key], = axis.plot(*np.vstack(data).T, **kwargs)
artists[key], = axis.plot(*np.vstack(data).T, animated=True,
**kwargs)

plot(rings, 'rings', data_style)
plot(h_rings, 'h_rings', highlight_style['data'])
Expand All @@ -404,7 +402,8 @@ def az_plot(data, key, kwargs):
x = np.repeat(xmeans, 3)
y = np.tile([0, 1, np.nan], len(xmeans))

artists[key], = az_axis.plot(x, y, transform=trans, **kwargs)
artists[key], = az_axis.plot(x, y, transform=trans,
animated=True, **kwargs)

az_plot(rings, 'az_rings', data_style)
# NOTE: we still use the data_style for az_axis highlighted rings
Expand Down Expand Up @@ -452,12 +451,14 @@ def split(data):
def scatter(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key] = axis.scatter(*np.asarray(data).T, **kwargs)
artists[key] = axis.scatter(*np.asarray(data).T, animated=True,
**kwargs)

def plot(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key], = axis.plot(*np.vstack(data).T, **kwargs)
artists[key], = axis.plot(*np.vstack(data).T, animated=True,
**kwargs)

# Draw spots and highlighted spots
scatter(spots, 'spots', data_style)
Expand Down Expand Up @@ -530,11 +531,12 @@ def draw_rotation_series_overlay(self, artist_key, det_key, axis, data,
overlay_artists = self.overlay_artists.setdefault(artist_key, {})
artists = overlay_artists.setdefault(det_key, {})

artists['data'] = axis.scatter(*sliced_data.T, **data_style)
artists['data'] = axis.scatter(*sliced_data.T, animated=True,
**data_style)

sliced_ranges = np.asarray(ranges)[slicer]
artists['ranges'], = axis.plot(*np.vstack(sliced_ranges).T,
**ranges_style)
animated=True, **ranges_style)

def redraw_overlay(self, overlay):
# Remove the artists for this overlay
Expand Down Expand Up @@ -573,7 +575,7 @@ def update_overlays(self):
for overlay in HexrdConfig().overlays:
self.draw_overlay(overlay)

self.draw_idle()
self.blit_manager.update()

def clear_detector_borders(self):
while self.cached_detector_borders:
Expand Down Expand Up @@ -825,9 +827,9 @@ def finish_show_cartesian(self, iviewer):
self.figure.tight_layout()

self.update_auto_picked_data()
self.update_overlays()
self.draw_detector_borders()
self.update_beam_marker()
self.update_overlays()

HexrdConfig().image_view_loaded.emit({'img': img})

Expand Down Expand Up @@ -952,9 +954,9 @@ def finish_show_polar(self, iviewer):
self.figure.tight_layout()

self.update_auto_picked_data()
self.update_overlays()
self.draw_detector_borders()
self.update_beam_marker()
self.update_overlays()

HexrdConfig().image_view_loaded.emit({'img': img})

Expand Down Expand Up @@ -1018,9 +1020,9 @@ def finish_show_stereo(self, iviewer):

self.draw_stereo_border()
self.update_auto_picked_data()
self.update_overlays()
self.draw_detector_borders()
self.update_beam_marker()
self.update_overlays()

HexrdConfig().image_view_loaded.emit({'img': img})

Expand Down

0 comments on commit 211ece7

Please sign in to comment.