Skip to content

Commit

Permalink
Merge pull request #4359 from jenshnielsen/dond_refactor
Browse files Browse the repository at this point in the history
Refactor internals of do_nd
  • Loading branch information
jenshnielsen authored Jul 5, 2022
2 parents 00b6bfa + 988650f commit 7e8d69c
Showing 1 changed file with 111 additions and 39 deletions.
150 changes: 111 additions & 39 deletions qcodes/dataset/do_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,88 @@ def post_actions(self) -> ActionsT:
return self._post_actions


class _Sweeper:
def __init__(
self,
sweeps: Sequence[AbstractSweep],
additional_setpoints: Sequence[ParameterBase],
):
self._sweeps = sweeps
self._additional_setpoints = additional_setpoints
self._nested_setpoints = self._make_nested_setpoints()

self._post_delays = tuple(sweep.delay for sweep in sweeps)
self._params_set = tuple(sweep.param for sweep in sweeps)
self._post_actions = tuple(sweep.post_actions for sweep in sweeps)

def _make_nested_setpoints(self) -> np.ndarray:
"""Create the cartesian product of all the setpoint values."""
if len(self._sweeps) == 0:
return np.array([[]]) # 0d sweep (do0d)
setpoint_values = [sweep.get_setpoints() for sweep in self._sweeps]
return self._flatten_setpoint_values(setpoint_values)

@staticmethod
def _flatten_setpoint_values(setpoint_values: Sequence[np.ndarray]) -> np.ndarray:
setpoint_grids = np.meshgrid(*setpoint_values, indexing="ij")
flat_setpoint_grids = [np.ravel(grid, order="C") for grid in setpoint_grids]
return np.vstack(flat_setpoint_grids).T

@property
def nested_setpoints(self) -> np.ndarray:
return self._nested_setpoints

@property
def all_setpoint_params(self) -> Tuple[ParameterBase, ...]:
return tuple(sweep.param for sweep in self._sweeps) + tuple(
s for s in self._additional_setpoints
)

@property
def shape(self) -> Tuple[int, ...]:
loop_shape = tuple(sweep.num_points for sweep in self._sweeps) + tuple(
1 for _ in self._additional_setpoints
)
return loop_shape

@property
def post_delays(self) -> Tuple[float, ...]:
return self._post_delays

@property
def params_set(self) -> Tuple[ParameterBase, ...]:
return self._params_set

@property
def post_actions(self) -> Tuple[ActionsT, ...]:
return self._post_actions


class _Measurements:
def __init__(
self,
measurement_name: str,
params_meas: Sequence[Union[ParamMeasT, Sequence[ParamMeasT]]],
):
(
self._measured_all,
self._grouped_parameters,
self._measured_parameters,
) = _extract_paramters_by_type_and_group(measurement_name, params_meas)

@property
def measured_all(self) -> Tuple[ParamMeasT, ...]:
return self._measured_all

@property
def grouped_parameters(self) -> Dict[str, ParameterGroup]:
return self._grouped_parameters

@property
def measured_parameters(self) -> Tuple[ParameterBase, ...]:
return self._measured_parameters


def dond(
*params: Union[AbstractSweep, Union[ParamMeasT, Sequence[ParamMeasT]]],
write_period: Optional[float] = None,
Expand Down Expand Up @@ -757,108 +839,98 @@ def dond(
show_progress = config.dataset.dond_show_progress

sweep_instances, params_meas = _parse_dond_arguments(*params)
nested_setpoints = _make_nested_setpoints(sweep_instances)

all_setpoint_params = tuple(sweep.param for sweep in sweep_instances) + tuple(
s for s in additional_setpoints
)
sweeper = _Sweeper(sweep_instances, additional_setpoints)

measurements = _Measurements(measurement_name, params_meas)

(
measured_all,
grouped_parameters,
measured_parameters,
) = _extract_paramters_by_type_and_group(measurement_name, params_meas)
LOG.info(
"Starting a doNd with scan with\n setpoints: %s,\n measuring: %s",
all_setpoint_params,
measured_all,
sweeper.all_setpoint_params,
measurements.measured_all,
)
LOG.debug(
"Measured parameters have been grouped into:\n " "%s",
{name: group["params"] for name, group in grouped_parameters.items()},
{
name: group["params"]
for name, group in measurements.grouped_parameters.items()
},
)
try:
loop_shape = tuple(sweep.num_points for sweep in sweep_instances) + tuple(
1 for _ in additional_setpoints
loop_shape = sweeper.shape
shapes: Shapes = detect_shape_of_measurement(
measurements.measured_parameters, loop_shape
)
shapes: Shapes = detect_shape_of_measurement(measured_parameters, loop_shape)
LOG.debug("Detected shapes to be %s", shapes)
except TypeError:
LOG.exception(
f"Could not detect shape of {measured_parameters} "
f"Could not detect shape of {measurements.measured_parameters} "
f"falling back to unknown shape."
)
shapes = None
meas_list = _create_measurements(
all_setpoint_params,
sweeper.all_setpoint_params,
enter_actions,
exit_actions,
exp,
grouped_parameters,
measurements.grouped_parameters,
shapes,
write_period,
log_info,
)

post_delays: List[float] = []
params_set: List[ParameterBase] = []
post_actions: List[ActionsT] = []
for sweep in sweep_instances:
post_delays.append(sweep.delay)
params_set.append(sweep.param)
post_actions.append(sweep.post_actions)

datasets = []
plots_axes = []
plots_colorbar = []
if use_threads is None:
use_threads = config.dataset.use_threads

params_meas_caller = (
ThreadPoolParamsCaller(*measured_all)
ThreadPoolParamsCaller(*measurements.measured_all)
if use_threads
else SequentialParamsCaller(*measured_all)
else SequentialParamsCaller(*measurements.measured_all)
)

try:
with _catch_interrupts() as interrupted, ExitStack() as stack, params_meas_caller as call_params_meas:
datasavers = [stack.enter_context(measure.run()) for measure in meas_list]
additional_setpoints_data = process_params_meas(additional_setpoints)
previous_setpoints = np.empty(len(sweep_instances))
for setpoints in tqdm(nested_setpoints, disable=not show_progress):
for setpoints in tqdm(sweeper.nested_setpoints, disable=not show_progress):

active_actions, delays = _select_active_actions_delays(
post_actions,
post_delays,
sweeper.post_actions,
sweeper.post_delays,
setpoints,
previous_setpoints,
)
previous_setpoints = setpoints

param_set_list = []
param_value_action_delay = zip(
params_set,
for setpoint_param, setpoint, action, delay in zip(
sweeper.params_set,
setpoints,
active_actions,
delays,
)
for setpoint_param, setpoint, action, delay in param_value_action_delay:
):
_conditional_parameter_set(setpoint_param, setpoint)
param_set_list.append((setpoint_param, setpoint))
for act in action:
act()
time.sleep(delay)

meas_value_pair = call_params_meas()
for group in grouped_parameters.values():
for group in measurements.grouped_parameters.values():
group["measured_params"] = []
for measured in meas_value_pair:
if measured[0] in group["params"]:
group["measured_params"].append(measured)
for ind, datasaver in enumerate(datasavers):
datasaver.add_result(
*param_set_list,
*grouped_parameters[f"group_{ind}"]["measured_params"],
*measurements.grouped_parameters[f"group_{ind}"][
"measured_params"
],
*additional_setpoints_data,
)

Expand All @@ -875,7 +947,7 @@ def dond(
plots_axes.append(plot_axis)
plots_colorbar.append(plot_color)

if len(grouped_parameters) == 1:
if len(measurements.grouped_parameters) == 1:
return datasets[0], plots_axes[0], plots_colorbar[0]
else:
return tuple(datasets), tuple(plots_axes), tuple(plots_colorbar)
Expand Down Expand Up @@ -910,7 +982,7 @@ def _conditional_parameter_set(
parameter.set(value)


def _make_nested_setpoints(sweeps: List[AbstractSweep]) -> np.ndarray:
def _make_nested_setpoints(sweeps: Sequence[AbstractSweep]) -> np.ndarray:
"""Create the cartesian product of all the setpoint values."""
if len(sweeps) == 0:
return np.array([[]]) # 0d sweep (do0d)
Expand Down

0 comments on commit 7e8d69c

Please sign in to comment.