Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to stitch ROI images in raw view #1655

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions hexrdgui/calibration/raw_iviewer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from itertools import chain

import numpy as np

from hexrdgui.constants import ViewType
from hexrdgui.create_hedm_instrument import create_hedm_instrument
from hexrdgui.hexrd_config import HexrdConfig
Expand All @@ -13,6 +17,9 @@ class InstrumentViewer:
def __init__(self):
self.type = ViewType.raw
self.instr = create_hedm_instrument()
self.roi_info = {}

self.setup_roi_info()

def update_overlay_data(self):
update_overlay_data(self.instr, self.type)
Expand All @@ -26,3 +33,155 @@ def update_detector(self, det):
self.instr.detectors[det].tilt = t_conf['tilt']

# Since these are just individual images, no further updates are needed

@property
def has_roi(self):
# Assume it has ROI support if a single detector supports it
panel = next(iter(self.instr.detectors.values()))
return all(x is not None for x in (panel.roi, panel.group))

@property
def roi_groups(self):
return self.roi_info.get('groups', {})

@property
def roi_stitched_shapes(self):
return self.roi_info.get('stitched_shapes', {})

def setup_roi_info(self):
if not self.has_roi:
# Required info is missing
return

groups = {}
for det_key, panel in self.instr.detectors.items():
groups.setdefault(panel.group, []).append(det_key)

self.roi_info['groups'] = groups

# Find the ROI bounds for each group
stitched_shapes = {}
for group, det_keys in groups.items():
row_size = 0
col_size = 0
for det_key in det_keys:
panel = self.instr.detectors[det_key]
row_size = max(row_size, panel.roi[0][1])
col_size = max(col_size, panel.roi[1][1])

stitched_shapes[group] = (row_size, col_size)

self.roi_info['stitched_shapes'] = stitched_shapes

def raw_to_stitched(self, ij, det_key):
ij = np.array(ij)

panel = self.instr.detectors[det_key]

if ij.size != 0:
ij[:, 0] += panel.roi[0][0]
ij[:, 1] += panel.roi[1][0]

return ij, panel.group

def stitched_to_raw(self, ij, stitched_key):
ij = np.atleast_2d(ij)

ret = {}
for det_key in self.roi_groups[stitched_key]:
panel = self.instr.detectors[det_key]
on_panel_rows = (
in_range(ij[:, 0], panel.roi[0]) &
in_range(ij[:, 1], panel.roi[1])
)
if np.any(on_panel_rows):
new_ij = ij[on_panel_rows]
new_ij[:, 0] -= panel.roi[0][0]
new_ij[:, 1] -= panel.roi[1][0]
ret[det_key] = new_ij

return ret

def raw_images_to_stitched(self, group_names, images_dict):
shapes = self.roi_stitched_shapes
stitched = {}
for group in group_names:
for det_key in self.roi_groups[group]:
panel = self.instr.detectors[det_key]
if group not in stitched:
stitched[group] = np.empty(shapes[group])
image = stitched[group]
roi = panel.roi
image[slice(*roi[0]), slice(*roi[1])] = images_dict[det_key]

return stitched

def create_overlay_data(self, overlay):
if HexrdConfig().stitch_raw_roi_images:
return self.create_roi_overlay_data(overlay)

return overlay.data

def create_roi_overlay_data(self, overlay):
ret = {}
for det_key, data in overlay.data.items():
panel = self.instr.detectors[det_key]

def raw_to_stitched(x):
# x is in "ji" coordinates
x[:, 0] += panel.roi[1][0]
x[:, 1] += panel.roi[0][0]

group = panel.group
ret.setdefault(group, {})
for data_key, entry in data.items():
if data_key == overlay.ranges_indices_key:
# Need to adjust indices since we stack ranges
if data_key not in ret[group]:
ret[group][data_key] = entry
continue

# We need to adjust these rbnd_indices. Find the max.
prev_max = max(chain(*ret[group][data_key]))
for i, x in enumerate(entry):
entry[i] = [j + prev_max + 1 for j in x]
ret[group][data_key] += entry
continue

if data_key not in overlay.plot_data_keys:
# This is not for plotting. No conversions needed.
ret[group][data_key] = entry
continue

if len(entry) == 0:
continue

# Points are 2D in shape, and lines are 3D in shape.
# Perorm the conversion regardless of the dimensions.
if isinstance(entry, list) or entry.ndim == 3:
for x in entry:
# This will convert in-place since `x` is a view
raw_to_stitched(x)
else:
raw_to_stitched(entry)

if data_key in ret[group]:
# Stack it with previous entries
if isinstance(ret[group][data_key], list):
entry = ret[group][data_key] + entry
else:
entry = np.vstack((ret[group][data_key], entry))

ret[group][data_key] = entry

# If data was missing for a whole group, set it to an empty list.
for group in ret:
for data_key in overlay.plot_data_keys:
if data_key not in ret[group]:
ret[group][data_key] = []

return ret


def in_range(x, range):
return (range[0] <= x) & (x < range[1])
35 changes: 30 additions & 5 deletions hexrdgui/hexrd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def __init__(self):
self.hdf5_path = []
self.live_update = True
self._show_saturation_level = False
self._stitch_raw_roi_images = False
self._tab_images = False
self.previous_active_material = None
self.collapsed_state = []
Expand Down Expand Up @@ -367,6 +368,7 @@ def _attributes_to_persist(self):
('config_calibration', None),
('config_indexing', None),
('config_image', None),
('_stitch_raw_roi_images', False),
('font_size', 11),
('images_dir', None),
('working_dir', '.'),
Expand Down Expand Up @@ -1554,11 +1556,23 @@ def default_detector(self):
self.default_config['instrument']['detectors']['detector_1'])

@property
def is_roi_instrument_config(self):
for det in HexrdConfig().detectors.values():
if det.get('pixels', {}).get('roi', {}).get('value'):
return True
return False
def instrument_has_roi(self):
det = next(iter(self.detectors.values()))

# Both the group and roi must be present to support ROI
has_group = det.get('group', {}).get('value')
has_roi = det.get('pixels', {}).get('roi', {}).get('value')
return bool(has_group and has_roi)

@property
def detector_group_names(self):
names = []
for det_key in self.detectors:
name = self.detector_group(det_key)
if name and name not in names:
names.append(name)

return names

def detector_group(self, detector_name):
det = self.detector(detector_name)
Expand Down Expand Up @@ -2383,6 +2397,17 @@ def set_show_saturation_level(self, v):
show_saturation_level = property(get_show_saturation_level,
set_show_saturation_level)

def get_stitch_raw_roi_images(self):
return self.instrument_has_roi and self._stitch_raw_roi_images

def set_stitch_raw_roi_images(self, v):
if self._stitch_raw_roi_images != v:
self._stitch_raw_roi_images = v
self.deep_rerender_needed.emit()

stitch_raw_roi_images = property(get_stitch_raw_roi_images,
set_stitch_raw_roi_images)

def tab_images(self):
return self._tab_images

Expand Down
62 changes: 44 additions & 18 deletions hexrdgui/image_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, parent=None, image_names=None):
self.stereo_border_artists = []
self.azimuthal_overlay_artists = []
self.blit_manager = BlitManager(self)
self.raw_view_images_dict = {}

# Track the current mode so that we can more lazily clear on change.
self.mode = None
Expand Down Expand Up @@ -115,6 +116,7 @@ def __del__(self):
def clear(self):
self.iviewer = None
self.mode = None
self.raw_view_images_dict = {}
self.clear_figure()

def clear_figure(self):
Expand Down Expand Up @@ -155,6 +157,11 @@ def load_images(self, image_names):
# This will be used for drawing the rings
self.iviewer = raw_iviewer()

if HexrdConfig().stitch_raw_roi_images:
# The image_names is actually a list of group names
images_dict = self.iviewer.raw_images_to_stitched(
image_names, images_dict)

cols = 1
if len(image_names) > 1:
cols = 2
Expand All @@ -178,17 +185,21 @@ def load_images(self, image_names):
self.figure.tight_layout()
else:
images_dict = self.scaled_image_dict
if HexrdConfig().stitch_raw_roi_images:
# The image_names is actually a list of group names
images_dict = self.iviewer.raw_images_to_stitched(
image_names, images_dict)

for i, name in enumerate(image_names):
img = images_dict[name]
self.axes_images[i].set_data(img)

self.raw_view_images_dict = images_dict

# This will call self.draw_idle()
self.show_saturation()

self.update_beam_marker()

# Set the detectors to draw
self.iviewer.detectors = list(self.raw_axes)
self.update_auto_picked_data()
self.update_overlays()

Expand Down Expand Up @@ -261,14 +272,15 @@ def overlay_axes_data(self, overlay):
return []

if self.mode == ViewType.raw:
data = self.iviewer.create_overlay_data(overlay)

# If it's raw, there is data for each axis.
# The title of each axis should match the detector key.
# Add a safety check to ensure everything is synced up.
if not all(x in overlay.data for x in self.raw_axes):
if not all(x in data for x in self.raw_axes):
return []

return [(self.raw_axes[x], x, overlay.data[x])
for x in self.raw_axes]
return [(self.raw_axes[x], x, data[x]) for x in self.raw_axes]

# If it is anything else, there is only one axis
# Use the same axis for all of the data
Expand Down Expand Up @@ -373,8 +385,8 @@ def split(data):
if not found:
# Not highlighted or merged
reg_ranges.append(ranges[i])
else:
found = False

found = False

def plot(data, key, kwargs):
# This logic was repeated
Expand Down Expand Up @@ -542,8 +554,7 @@ def draw_rotation_series_overlay(self, artist_key, det_key, axis, data,

def draw_const_chi_overlay(self, artist_key, det_key, axis, data, style,
highlight_style):
points = [x['data'] for x in data]

points = data['data']
data_style = style['data']

overlay_artists = self.overlay_artists.setdefault(artist_key, {})
Expand Down Expand Up @@ -698,21 +709,36 @@ def show_saturation(self):
# Use the unscaled image data to determine saturation
images_dict = self.unscaled_image_dict

def compute_saturation_and_size(detector_name):
detector = HexrdConfig().detector(detector_name)
saturation_level = detector['saturation_level']['value']

array = images_dict[detector_name]

num_sat = (array >= saturation_level).sum()

return num_sat, array.size

for img in self.axes_images:
# The titles of the images are currently the detector names
# If we change this in the future, we will need to change
# our method for getting the saturation level as well.
ax = img.axes
detector_name = ax.get_title()
detector = HexrdConfig().detector(detector_name)
saturation_level = detector['saturation_level']['value']

array = images_dict[detector_name]
axes_title = ax.get_title()

num_sat = (array >= saturation_level).sum()
percent = num_sat / array.size * 100.0
str_sat = 'Saturation: ' + str(num_sat)
str_sat += '\n%5.3f %%' % percent
if HexrdConfig().stitch_raw_roi_images:
# The axes title is the group name
det_keys = self.iviewer.roi_groups[axes_title]
results = [compute_saturation_and_size(x) for x in det_keys]
num_sat = sum(x[0] for x in results)
size = sum(x[1] for x in results)
else:
# The axes_title is the detector name
num_sat, size = compute_saturation_and_size(axes_title)

percent = num_sat / size * 100.0
str_sat = f'Saturation: {num_sat}\n%{percent:5.3f}'

t = ax.text(0.05, 0.05, str_sat, fontdict={'color': 'w'},
transform=ax.transAxes)
Expand Down
2 changes: 1 addition & 1 deletion hexrdgui/image_load_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def thread_pool(self):
@property
def naming_options(self):
dets = HexrdConfig().detector_names
if HexrdConfig().is_roi_instrument_config:
if HexrdConfig().instrument_has_roi:
groups = [HexrdConfig().detector_group(d) for d in dets]
dets.extend([g for g in groups if g is not None])
return dets
Expand Down
Loading
Loading