Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor overlay drawing to use fewer artists #1614

Merged
merged 4 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 185 additions & 117 deletions hexrdgui/image_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from PySide6.QtWidgets import QFileDialog, QMessageBox

from matplotlib.backends.backend_qt5agg import FigureCanvas

from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
from matplotlib.ticker import AutoLocator, FuncFormatter

import matplotlib.pyplot as plt
import matplotlib.transforms as tx

import numpy as np

Expand All @@ -23,6 +24,7 @@
from hexrdgui.hexrd_config import HexrdConfig
from hexrdgui.snip_viewer_dialog import SnipViewerDialog
from hexrdgui import utils
from hexrdgui.utils.array import split_array
from hexrdgui.utils.conversions import (
angles_to_stereo, cart_to_angles, cart_to_pixels, q_to_tth, tth_to_q,
)
Expand Down Expand Up @@ -232,9 +234,14 @@ def remove_overlay_artists(self, key):
if key not in self.overlay_artists:
return

artists = self.overlay_artists[key]
while artists:
artists.pop(0).remove()
for det_key, artist_dict in self.overlay_artists[key].items():
for artist_name, artist in artist_dict.items():
if isinstance(artist, list):
while artist:
artist.pop(0).remove()
else:
artist.remove()

del self.overlay_artists[key]

def prune_overlay_artists(self):
Expand Down Expand Up @@ -262,11 +269,12 @@ def overlay_axes_data(self, overlay):
if not all(x in overlay.data for x in self.raw_axes):
return []

return [(self.raw_axes[x], overlay.data[x]) for x in self.raw_axes]
return [(self.raw_axes[x], x, overlay.data[x])
for x in self.raw_axes]

# If it is anything else, there is only one axis
# Use the same axis for all of the data
return [(self.axis, x) for x in overlay.data.values()]
return [(self.axis, k, v) for k, v in overlay.data.items()]

def overlay_draw_func(self, type):
overlay_funcs = {
Expand Down Expand Up @@ -306,84 +314,111 @@ def draw_overlay(self, overlay):
type = overlay.type
style = overlay.style
highlight_style = overlay.highlight_style
for axis, data in self.overlay_axes_data(overlay):
for axis, det_key, data in self.overlay_axes_data(overlay):
kwargs = {
'artist_key': overlay.name,
'det_key': det_key,
'axis': axis,
'data': data,
'style': style,
'highlight_style': highlight_style,
}
self.overlay_draw_func(type)(**kwargs)

def draw_powder_overlay(self, artist_key, axis, data, style,
def draw_powder_overlay(self, artist_key, det_key, axis, data, style,
highlight_style):
rings = data['rings']
rbnds = data['rbnds']
ranges = data['rbnds']
rbnd_indices = data['rbnd_indices']

data_style = style['data']
ranges_style = style['ranges']

highlight_indices = [i for i, x in enumerate(rings)
if id(x) in self.overlay_highlight_ids]

artists = self.overlay_artists.setdefault(artist_key, [])
for i, pr in enumerate(rings):
current_style = data_style
if i in highlight_indices:
# Override with highlight style
current_style = highlight_style['data']

x, y = pr.T
artist, = axis.plot(x, y, **current_style)
artists.append(artist)

# Add the rbnds too
for ind, pr in zip(rbnd_indices, rbnds):
x, y = pr.T
current_style = copy.deepcopy(ranges_style)
if any(x in highlight_indices for x in ind):
# Override with highlight style
current_style = highlight_style['ranges']
elif len(ind) > 1:
# If ranges are combined, override the color to red
current_style['c'] = 'r'
artist, = axis.plot(x, y, **current_style)
artists.append(artist)

if self.azimuthal_integral_axis is not None:
az_axis = self.azimuthal_integral_axis
for pr in rings:
x = pr[:, 0]
if len(x) == 0:
# Skip over rings that are out of bounds
continue

# Average the points together for the vertical line
x = np.nanmean(x)
artist = az_axis.axvline(x, **data_style)
artists.append(artist)

# Add the rbnds too
for ind, pr in zip(rbnd_indices, rbnds):
x = pr[:, 0]
if len(x) == 0:
# Skip over rbnds that are out of bounds
continue

# Average the points together for the vertical line
x = np.nanmean(x)

current_style = copy.deepcopy(ranges_style)
if len(ind) > 1:
# If rbnds are combined, override the color to red
current_style['c'] = 'r'

artist = az_axis.axvline(x, **current_style)
artists.append(artist)

def draw_laue_overlay(self, artist_key, axis, data, style,
merged_ranges_style = copy.deepcopy(ranges_style)
merged_ranges_style['c'] = 'r'

overlay_artists = self.overlay_artists.setdefault(artist_key, {})
artists = overlay_artists.setdefault(det_key, {})

highlight_indices = []

if self.overlay_highlight_ids:
# Split up highlighted and non-highlighted components for all
highlight_indices = [i for i, x in enumerate(rings)
if id(x) in self.overlay_highlight_ids]

def split(data):
if not highlight_indices or len(data) == 0:
return [], data

return split_array(data, highlight_indices)

h_rings, rings = split(rings)

# Find merged ranges and highlighted ranges
# Some ranges will be both "merged" and "highlighted"
merged_ranges = []
h_ranges = []
reg_ranges = []

found = False
for i, ind in enumerate(rbnd_indices):
if len(ind) > 1:
merged_ranges.append(ranges[i])
found = True

if highlight_indices and any(x in highlight_indices for x in ind):
h_ranges.append(ranges[i])
found = True

if not found:
# Not highlighted or merged
reg_ranges.append(ranges[i])
else:
found = False

def plot(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key], = axis.plot(*np.vstack(data).T, **kwargs)

plot(rings, 'rings', data_style)
plot(h_rings, 'h_rings', highlight_style['data'])

plot(reg_ranges, 'ranges', ranges_style)
plot(merged_ranges, 'merged_ranges', merged_ranges_style)
# Highlighting goes after merged ranges to get precedence
plot(h_ranges, 'h_ranges', highlight_style['ranges'])

az_axis = self.azimuthal_integral_axis
if az_axis:
trans = tx.blended_transform_factory(az_axis.transData,
az_axis.transAxes)

def az_plot(data, key, kwargs):
if len(data) == 0:
return

xmeans = np.array([np.nanmean(x[:, 0]) for x in data])

x = np.repeat(xmeans, 3)
y = np.tile([0, 1, np.nan], len(xmeans))

artists[key], = az_axis.plot(x, y, transform=trans, **kwargs)

az_plot(rings, 'az_rings', data_style)
# NOTE: we still use the data_style for az_axis highlighted rings
az_plot(h_rings, 'az_h_rings', data_style)

az_plot(reg_ranges, 'az_ranges', ranges_style)

# NOTE: we still use the ranges_style for az_axis highlighted rings
az_plot(h_ranges, 'az_h_ranges', ranges_style)

# Give merged ranges style precedence
az_plot(merged_ranges, 'az_merged_ranges', merged_ranges_style)

def draw_laue_overlay(self, artist_key, det_key, axis, data, style,
highlight_style):
spots = data['spots']
ranges = data['ranges']
Expand All @@ -394,79 +429,112 @@ def draw_laue_overlay(self, artist_key, axis, data, style,
ranges_style = style['ranges']
label_style = style['labels']

highlight_indices = [i for i, x in enumerate(spots)
if id(x) in self.overlay_highlight_ids]
highlight_indices = []

if self.overlay_highlight_ids:
# Split up highlighted and non-highlighted components for all
highlight_indices = [i for i, x in enumerate(spots)
if id(x) in self.overlay_highlight_ids]

def split(data):
if not highlight_indices or len(data) == 0:
return [], data

return split_array(data, highlight_indices)

h_spots, spots = split(spots)
h_ranges, ranges = split(ranges)
h_labels, labels = split(labels)

artists = self.overlay_artists.setdefault(artist_key, [])
for i, (x, y) in enumerate(spots):
current_style = data_style
if i in highlight_indices:
current_style = highlight_style['data']
overlay_artists = self.overlay_artists.setdefault(artist_key, {})
artists = overlay_artists.setdefault(det_key, {})

artist = axis.scatter(x, y, **current_style)
artists.append(artist)
def scatter(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key] = axis.scatter(*np.asarray(data).T, **kwargs)

if labels:
current_label_style = label_style
if i in highlight_indices:
current_label_style = highlight_style['labels']
def plot(data, key, kwargs):
# This logic was repeated
if len(data) != 0:
artists[key], = axis.plot(*np.vstack(data).T, **kwargs)

# Draw spots and highlighted spots
scatter(spots, 'spots', data_style)
scatter(h_spots, 'h_spots', highlight_style['data'])

# Draw ranges and highlighted ranges
plot(ranges, 'ranges', ranges_style)
plot(h_ranges, 'h_ranges', highlight_style['ranges'])

# Draw labels and highlighted labels
if len(labels) or len(h_labels):
def plot_label(x, y, label, style):
kwargs = {
'x': x + label_offsets[0],
'y': y + label_offsets[1],
's': labels[i],
's': label,
'clip_on': True,
**current_label_style,
**style,
}
artist = axis.text(**kwargs)
artists.append(artist)

for i, box in enumerate(ranges):
current_style = ranges_style
if i in highlight_indices:
current_style = highlight_style['ranges']

x, y = zip(*box)
artist, = axis.plot(x, y, **current_style)
artists.append(artist)

def draw_rotation_series_overlay(self, artist_key, axis, data, style,
highlight_style):
return axis.text(**kwargs)

# I don't know of a way to use a single artist for all labels.
# FIXME: figure out how to make this faster, if needed.
artists.setdefault('labels', [])
for label, (x, y) in zip(labels, spots):
artists['labels'].append(plot_label(x, y, label, label_style))

# I don't know of a way to use a single artist for all labels.
# FIXME: figure out how to make this faster, if needed.
artists.setdefault('h_labels', [])
style = highlight_style['labels']
for label, (x, y) in zip(h_labels, h_spots):
artists['h_labels'].append(plot_label(x, y, label, style))

def draw_rotation_series_overlay(self, artist_key, det_key, axis, data,
style, highlight_style):
is_aggregated = HexrdConfig().is_aggregated
ome_range = HexrdConfig().omega_ranges
aggregated = data['aggregated'] or is_aggregated or ome_range is None
if not aggregated:
ome_width = data['omega_width']
ome_mean = np.mean(ome_range)
full_range = (ome_mean - ome_width / 2, ome_mean + ome_width / 2)

def in_range(x):
return aggregated or full_range[0] <= x <= full_range[1]

# Compute the indices that are in range for the current omega value
ome_points = data['omegas']
indices_in_range = [i for i, x in enumerate(ome_points) if in_range(x)]

if aggregated:
# This means we will keep all
slicer = slice(None)
else:
ome_width = data['omega_width']
ome_mean = np.mean(ome_range)
ome_min = ome_mean - ome_width / 2
ome_max = ome_mean + ome_width / 2

in_range = np.logical_and(ome_min <= ome_points,
ome_points <= ome_max)
slicer = np.where(in_range)

data_points = data['data']
ranges = data['ranges']

data_style = style['data']
ranges_style = style['ranges']

artists = self.overlay_artists.setdefault(artist_key, [])
for i in indices_in_range:
# data
x, y = data_points[i]
artist = axis.scatter(x, y, **data_style)
artists.append(artist)
if len(data_points) == 0:
return

# ranges
if i >= len(ranges):
continue
sliced_data = data_points[slicer]
if len(sliced_data) == 0:
return

overlay_artists = self.overlay_artists.setdefault(artist_key, {})
artists = overlay_artists.setdefault(det_key, {})

artists['data'] = axis.scatter(*sliced_data.T, **data_style)

x, y = zip(*ranges[i])
artist, = axis.plot(x, y, **ranges_style)
artists.append(artist)
sliced_ranges = np.asarray(ranges)[slicer]
artists['ranges'], = axis.plot(*np.vstack(sliced_ranges).T,
**ranges_style)

def redraw_overlay(self, overlay):
# Remove the artists for this overlay
Expand Down
Loading
Loading