diff --git a/twop/drift_correction.py b/twop/drift_correction.py new file mode 100644 index 0000000..8b97a6d --- /dev/null +++ b/twop/drift_correction.py @@ -0,0 +1,251 @@ +from multiprocessing import Process, Queue +import numpy as np +from lightparam.param_qt import ParametrizedQt +from lightparam import Param +from skimage.feature import register_translation +import flammkuchen as fl +from matplotlib import Path +from queue import Empty +from dataclasses import dataclass +from time import sleep +from scipy.ndimage.filters import gaussian_filter +from twop.objective_motor_sliders import MovementType +from scipy.signal import convolve + + +class ReferenceSettings(ParametrizedQt): + def __init__(self): + super().__init__() + self.name = "reference" + self.n_frames_ref = Param(10, (1, 500)) + self.extra_planes = Param(11, (1, 500)) + self.dz = Param(1.0, (0.1, 20.0), unit="um") + self.xy_th = Param(5.0, (0.1, 20.0), unit="um") + self.z_th = Param(self.dz, (self.dz, self.dz * 4), unit="um") + self.n_frames_exp = Param(5, (1, 500)) + self.size_k = Param(0, (0, 100)) + self.sigma_k = Param(2, (1, 10)) + + + +@dataclass +class ReferenceParameters: + n_frames_ref: int = 10 + extra_planes: int = 10 + dz: float = 1.0 + xy_th: float = 5 + z_th: float = 1 + n_frames_exp: int = 5 + size_k: int = 0 + sigma_k: int = 5 + + +def convert_reference_params(st: ReferenceSettings) -> ReferenceParameters: + n_frames_ref = st.n_frames_ref + extra_planes = st.extra_planes + xy_th = st.xy_th + dz = st.dz + z_th = st.z_th + n_frames_exp = st.n_frames_exp + sigma_k = st.sigma_k + size_k = st.size_k + rp = ReferenceParameters(n_frames_ref=n_frames_ref, + extra_planes=extra_planes, + dz=dz, + xy_th=xy_th, + z_th=z_th, + n_frames_exp=n_frames_exp, + sigma_k=sigma_k, + size_k=size_k) + + return rp + + +class Corrector(Process): + def __init__(self, reference_event, experiment_start_event, stop_event, correction_event, + reference_queue, scanning_parameters, + scanning_parameters_queue, data_queue, + input_commands_queues, output_positions_queues, save_param): + super().__init__() + # communication with other processes, active during acquisition of the reference + self.reference_event = reference_event + # communication with other processes, active during acquisition of the reference and experiment + # (all the processes use it) + self.experiment_start_event = experiment_start_event + # communication with other processes, the status is not modified by any process + self.stop_event = stop_event + # communication with other processes, active during experiment (when correction is allowed) + self.correction_event = correction_event + + # queue for the acquisition of the reference, planes are sent by the saver + self.reference_queue = reference_queue + # queue in order to know the latest settings selected by the user such as n_planes, n_frames etc. + self.reference_param_queue = Queue() + # queue in order to know the latest scanning settings, for n_x and n_y + self.scanning_parameters_queue = scanning_parameters_queue + # initial scanning parameters + self.scanning_parameters = scanning_parameters + # queue for getting the copy of the frames during an experiment + self.data_queue = data_queue + # queue for the communication with the master motor class (in a separate process), this is for move the motors + self.input_commands_queues = input_commands_queues + # queue for the communication with the master motor class (in a separate process), + # this is for read the last position + self.output_positions_queues = output_positions_queues + # to know the directory where the anatomy/reference will be saved + self.save_parameters = save_param + + self.x_pos = None + self.y_pos = None + self.z_pos = None + self.mov_type = MovementType(False) + + self.reference = None + self.reference_params = None + self.calibration_vector = None + + def run(self): + while True: + self.update_settings() + if self.reference_event.is_set() and self.experiment_start_event.is_set(): + self.start_ref_acquisition() + self.reference_loop() + self.reference_event.clear() + + elif not self.reference_event.is_set() and self.experiment_start_event.is_set(): + self.correction_event.set() + self.exp_loop() + self.correction_event.clear() + + def reference_loop(self): + stack_4d = self.get_next_entry(self.reference_queue) + self.save_reference(stack_4d) + self.reference = self.reference_processing(stack_4d) + print(self.reference.shape) + self.end_ref_acquisition() + + def compute_registration(self, test_image): + vectors = [] + errors = [] + planes = np.size(self.reference, 0) + for i in range(planes): + ref_im = np.squeeze(self.reference[i, :, :]) + output = register_translation(ref_im, test_image) + vectors.append(output[0]) + errors.append(output[1]) + ind = errors.index(min(errors)) + z_disp = ind - ((self.reference_params.n_planes - 1) / 2) + vector = vectors[ind] + np.append(vector, z_disp) + vector = self.real_units(vector) + return vector + + def exp_loop(self): + pix_millimeter = self.calculate_fov() + self.calibration_vector = [pix_millimeter, pix_millimeter, self.reference_params.dz] # x,y,z cal vect + while not self.stop_event.is_set(): + number_of_frames = 0 + frame_container = [] + while number_of_frames == self.reference_params.n_frames_exp: + try: + frame = self.data_queue.get(timeout=0.001) + frame_container.append(frame) + number_of_frames += 1 + except Empty: + frame_container = frame_container[-self.reference_params.n_frames_exp:] + frame = self.frame_processing(frame_container) + vector = self.compute_registration(frame) + self.apply_correction(vector) + + def start_ref_acquisition(self): + self.x_pos = self.get_last_entry(self.output_positions_queues["x"]) + self.y_pos = self.get_last_entry(self.output_positions_queues["y"]) + self.z_pos = self.get_last_entry(self.output_positions_queues["z"]) + # self.reference_params.n_planes = self.reference_params.n_planes + 1 + up_planes = self.reference_params.extra_planes + distance = (self.reference_params.dz / 1000) * up_planes + self.input_commands_queues["z"].put((distance, self.mov_type)) + sleep(0.2) + + def end_ref_acquisition(self): + self.input_commands_queues["z"].put((self.z_pos, self.mov_type)) + + def real_units(self, raw_vector): + vector = np.multiply(raw_vector, self.calibration_vector) + return vector + + def apply_correction(self, vector): + self.input_commands_queues["x"].put((vector[1], self.mov_type)) + self.input_commands_queues["y"].put((vector[0], self.mov_type)) + self.input_commands_queues["z"].put((vector[2], self.mov_type)) + + def reference_processing(self, input_ref): + output_ref = np.mean(input_ref, axis=0) + size_kernel = self.reference_params.size_k + sigma = self.reference_params.sigma_k + if size_kernel != 0: + kernel = self.gaussian_kernel((size_kernel,)*3, sigma=sigma) + output_ref = convolve(output_ref, kernel, mode='same') + return output_ref + + @staticmethod + def frame_processing(frame_container): + frame_container_array = np.array(frame_container) + frame = np.mean(frame_container_array, 0) + return frame + + @staticmethod + def get_last_entry(queue): + out = None + while True: + try: + out = queue.get(timeout=0.001) + except Empty: + break + return out + + @staticmethod + def get_next_entry(queue): + out = None + while out is None: + try: + out = queue.get(timeout=0.001) + except Empty: + pass + return out + + @staticmethod + def gaussian_kernel(size_kernel, sigma=1): + size_kernel = np.ceil(size_kernel) // 2 * 2 + 1 + x = np.arange(- np.floor(size_kernel[0] / 2), np.ceil(size_kernel[0] / 2), 1) + y = np.arange(- np.floor(size_kernel[1] / 2), np.ceil(size_kernel[1] / 2), 1) + if len(size_kernel) == 3: + z = np.arange(- np.floor(size_kernel[2] / 2), np.ceil(size_kernel[2] / 2), 1) + xx, yy, zz = np.meshgrid(x, y, z) + kernel = np.exp(-(xx ** 2 + yy ** 2 + zz ** 2) / (2 * sigma ** 2)) + else: + xx, yy = np.meshgrid(x, y) + kernel = np.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2)) + return kernel + + def update_settings(self): + new_params = self.get_last_entry(self.scanning_parameters_queue) + if new_params is not None: + self.scanning_parameters = new_params + + def calculate_fov(self): + # calculate pix per millimeters + # formula: width FOV (microns) = 167.789 * Voltage + conv_fact = 167.789 + w_fov = conv_fact * self.scanning_parameters.voltage_x + return (self.scanning_parameters.n_x / w_fov) / 1000 + + def save_reference(self, raw_reference): + n_planes = self.reference.shape[1] + for plane in range(n_planes): + fl.save( + Path(self.save_parameters.output_dir) + / "anatomy/{:04d}.h5".format(plane), + {"stack_4D": raw_reference[:, plane, :, :]}, + compression="blosc", + ) \ No newline at end of file diff --git a/twop/gui.py b/twop/gui.py index 851b26c..e020f6c 100644 --- a/twop/gui.py +++ b/twop/gui.py @@ -11,7 +11,7 @@ QFileDialog, QCheckBox, ) -from state import ExperimentState, ScanningParameters, frame_duration +from twop.state import ExperimentState, ScanningParameters, frame_duration from twop.objective_motor_sliders import MotionControlXYZ import pyqtgraph as pg @@ -51,6 +51,7 @@ def __init__(self, state: ExperimentState): self.startstop_button = QPushButton() self.set_saving() self.chk_pause = QCheckBox("Pause after experiment") + self.chk_drift_corr = QCheckBox("Drift Correction") self.stack_progress = QProgressBar() self.plane_progress = QProgressBar() self.plane_progress.setFormat("Frame %v of %m") @@ -62,6 +63,7 @@ def __init__(self, state: ExperimentState): self.layout().addWidget(self.save_location_button) self.layout().addWidget(self.startstop_button) self.layout().addWidget(self.chk_pause) + self.layout().addWidget(self.chk_drift_corr) self.layout().addWidget(self.plane_progress) self.layout().addWidget(self.stack_progress) @@ -78,6 +80,11 @@ def set_notsaving(self): ) def toggle_start(self): + if self.chk_drift_corr.isChecked() is True: + self.state.reference_event.set() + else: + self.state.reference_event.clear() + if self.state.saving: self.state.end_experiment(force=True) self.set_saving() @@ -86,6 +93,8 @@ def toggle_start(self): if self.state.start_experiment(): self.set_notsaving() + + def set_locationbutton(self): pathtext = self.state.experiment_settings.save_dir # check if there is a stack in this location @@ -187,6 +196,16 @@ def toggle_pause(self): self.update_button() +class ReferenceWidget(QWidget): + def __init__(self, state: ExperimentState): + self.state = state + super().__init__() + self.reference_layout = QVBoxLayout() + self.reference_settings_gui = ParameterGui(self.state.reference_settings) + self.reference_layout.addWidget(self.reference_settings_gui) + self.setLayout(self.reference_layout) + + class TwopViewer(QMainWindow): def __init__(self): super().__init__() @@ -199,13 +218,18 @@ def __init__(self): self.scanning_widget = ScanningWidget(self.state) self.experiment_widget = ExperimentControl(self.state) + self.reference_widget = ReferenceWidget(self.state) - self.motor_control_slider = MotionControlXYZ(self.state.motors) + self.motor_control_slider = MotionControlXYZ(self.state.input_queues, self.state.output_queues) self.addDockWidget( Qt.LeftDockWidgetArea, DockedWidget(widget=self.scanning_widget, title="Scanning settings"), ) + self.addDockWidget( + Qt.LeftDockWidgetArea, + DockedWidget(widget=self.reference_widget, title="Reference settings"), + ) self.addDockWidget( Qt.RightDockWidgetArea, DockedWidget(widget=self.motor_control_slider, title="Stage control"), diff --git a/twop/objective_motor.py b/twop/objective_motor.py index 0cfbe3a..95cfa15 100644 --- a/twop/objective_motor.py +++ b/twop/objective_motor.py @@ -1,108 +1,137 @@ import pyvisa +from multiprocessing import Process +from queue import Empty + + +class MotorMaster(Process): + + def __init__( + self, + motors, + input_queues, + output_queues, + close_setup_event, + ): + super().__init__() + + self.input_queues = input_queues + self.output_queues = output_queues + self.close_setup_event = close_setup_event + self.motors = motors + self.positions = dict.fromkeys(motors) + self.motors_running = True + self.get_positions() + + def run(self) -> None: + while not self.close_setup_event.is_set(): + self.get_positions() + self.move_motors() + self.close_setups() + + def close_setups(self): + for axis in self.motors.keys(): + self.motors[axis].end_session() + + def get_positions(self): + for axis in self.motors.keys(): + actual_pos = self.motors[axis].get_position() + if actual_pos is not None: + self.positions[axis] = actual_pos + self.output_queues[axis].put(actual_pos) + + def move_motors(self): + for axis in self.motors.keys(): + package = self.get_last_entry(self.input_queues[axis]) + if package: + mov_value = package[0] + mov_type = package[1].name + empty_queue = False + else: + empty_queue = True + + if empty_queue is False: + if mov_type == "relative": + self.motors[axis].move_rel(mov_value) + elif mov_type == "absolute": + self.motors[axis].move_abs(mov_value) + + @staticmethod + def get_last_entry(queue): + out = tuple() + while True: + try: + out = queue.get(timeout=0.001) + except Empty: + break + return out class MotorControl: + def __init__( - self, - port, - baudrate=921600, - parity=pyvisa.constants.Parity.none, - encoding="ascii", - axes=None, + self, + port, + baudrate=921600, + parity=pyvisa.constants.Parity.none, + encoding="ascii", + axis=None, ): self.baudrate = baudrate self.parity = parity self.encoding = encoding self.port = port - axes = self.find_axis(axes) + axes = self.find_axis(axis) self.axes = str(axes) + self.x = 0 + self.y = 0 self.home_pos = None - rm = pyvisa.ResourceManager() - self.motor = rm.open_resource( - port, baud_rate=baudrate, parity=parity, encoding=encoding, timeout=10 - ) + self.motor = None self.start_session() self.connection = True def get_position(self): - input_m = self.axes + "TP" - try: - output = self.motor.query(input_m) - try: - output = [float(s) for s in output.split(",")] - except ValueError: - print("Got ", output, "from motor, what to do?") - return output[0] - - except pyvisa.VisaIOError: - print(f"Error get position axes number {self.axes} ") - return None - - def move_abs(self, coordinate): - coordinate = str(coordinate) - command = self.axes + "PA" + coordinate - self.execute_motor(command) + if self.axes == 1: + pos = self.x + elif self.axes == 2: + pos = self.y + else: + pos = 0 + return pos def move_rel(self, displacement=0.0): - displacement = str(displacement) - command = self.axes + "PR" + displacement - self.execute_motor(command) + if self.axes == 1: + self.x = self.x + displacement + elif self.axes == 2: + self.y = self.y + displacement + else: + pass def set_units(self, units): if units == "mm": units = 2 elif units == "um": units = 3 - command = self.axes + "SN" + str(units) - self.execute_motor(command) def define_home(self): self.home_pos = self.get_position() def go_home(self): - command = self.axes + "OR" + str(2) - self.execute_motor(command) + pass def execute_motor(self, command): - try: - self.motor.query(command) - except pyvisa.VisaIOError: - pass + print("motor", self.axes, "moved to:", command) def start_session(self): # motor on - command = self.axes + "MO" - self.execute_motor(command) - - # set trajectory mode to trapezoidal - command = self.axes + "TJ" + str(1) - self.execute_motor(command) - - # set jog high speed to 0.2 for x,y or to 0.5 for z - if self.axes == "1" or self.axes == "2": - command = self.axes + "JH" + str(0.2) - self.execute_motor(command) - elif self.axes == "3": - command = self.axes + "JH" + str(0.5) - self.execute_motor(command) - - # set jog low speed to 0.01 - command = self.axes + "TW" + str(0.01) - self.execute_motor(command) - - # define home position - self.define_home() - - # set mm as unit + print("motor", self.axes, "on") self.set_units("mm") def end_session(self): # motor off - command = self.axes + "MF" - self.execute_motor(command) + # command = self.axes + "MF" + # self.execute_motor(command) # close connection - self.motor.close() - self.connection = False + print("motor", self.axes, "off") @staticmethod def find_axis(axes): diff --git a/twop/objective_motor_sliders.py b/twop/objective_motor_sliders.py index dd66554..55b544a 100644 --- a/twop/objective_motor_sliders.py +++ b/twop/objective_motor_sliders.py @@ -11,23 +11,34 @@ from PyQt5.QtGui import QColor from PyQt5.QtCore import Qt, pyqtSignal, QPointF, QTimer +from queue import Empty +from enum import Enum + + +class MovementType(Enum): + absolute = True + relative = False class MotionControlXYZ(QWidget): - def __init__(self, motors): + def __init__(self, input_queues, output_queues): super().__init__() self.setLayout(QGridLayout()) - for key, value in motors.items(): - wid = MotorSlider(name=key, motor=value) + for axis in input_queues.keys(): + wid = MotorSlider(name=axis, + input_queue=input_queues[axis], + output_queue=output_queues[axis] + ) self.layout().addWidget(wid) class PrecisionSingleSliderMotorControl(PrecisionSingleSlider): - def __init__(self, *args, motor=None, pos=None, **kwargs): + def __init__(self, *args, input_queue=None, output_queue=None, pos=None, **kwargs): super().__init__(*args, **kwargs) self.pos = pos - self.motor = motor + self.input_queue = input_queue + self.output_queue = output_queue self.axes_pos = 0 self.indicator_color = QColor(178, 0, 0) @@ -60,7 +71,7 @@ class MotorSlider(QWidget): sig_changed = pyqtSignal(float) sig_end_session = pyqtSignal() - def __init__(self, motor=None, move_limit_low=-3, move_limit_high=3, name=""): + def __init__(self, input_queue=None, output_queue=None, move_limit_low=-3, move_limit_high=3, name=""): super().__init__() self.name = name self.grid_layout = QGridLayout() @@ -68,13 +79,15 @@ def __init__(self, motor=None, move_limit_low=-3, move_limit_high=3, name=""): self.grid_layout.setContentsMargins(0, 0, 0, 0) self.spin_val_desired_pos = QDoubleSpinBox() self.spin_val_actual_pos = QDoubleSpinBox() - - value = motor.home_pos + self.input_queue = input_queue + self.output_queue = output_queue + self.mov_type = MovementType(True) + value = self.output_queue.get(timeout=0.001) min_range = value + move_limit_low max_range = value + move_limit_high self.slider = PrecisionSingleSliderMotorControl( - default_value=value, min=min_range, max=max_range, pos=value, motor=motor + default_value=value, min=min_range, max=max_range, pos=value, input_queue=None, output_queue=None, ) for spin_val in [self.spin_val_actual_pos, self.spin_val_desired_pos]: spin_val.setRange(min_range, max_range) @@ -95,27 +108,29 @@ def __init__(self, motor=None, move_limit_low=-3, move_limit_high=3, name=""): self.setLayout(self.grid_layout) self.slider.sig_changed.connect(self.update_values) - self.sig_changed.connect(self.slider.motor.move_abs) self._timer_painter = QTimer(self) self._timer_painter.timeout.connect(self.update_actual_pos) self._timer_painter.start() def update_actual_pos(self): - if self.slider.motor.connection is True: - pos = self.slider.motor.get_position() - if pos is not None: - self.spin_val_actual_pos.setValue(pos) - self.slider.axes_pos = pos + while True: + try: + pos = self.output_queue.get(timeout=0.001) + except Empty: + break + self.spin_val_actual_pos.setValue(pos) + self.slider.axes_pos = pos def update_values(self, val): self.spin_val_desired_pos.setValue(val) - self.sig_changed.emit(val) + self.input_queue.put((val, self.mov_type)) def update_slider(self, new_val): self.slider.pos = new_val self.slider.update() - self.sig_changed.emit(new_val) + + self.input_queue.put((new_val, self.mov_type)) def update_external(self, new_val): self.slider.pos = new_val diff --git a/twop/power_control.py b/twop/power_control.py index dddc553..293cdbd 100644 --- a/twop/power_control.py +++ b/twop/power_control.py @@ -19,14 +19,7 @@ def __init__( self.encoding = encoding self.port = port self.device = device - rm = pyvisa.ResourceManager() - self.rotatory_stage = rm.open_resource( - port, - baud_rate=self.baudrate, - parity=self.parity, - encoding=self.encoding, - open_timeout=1, - ) + self.rotatory_stage = None self.execute_home_search() def get_position(self): @@ -37,24 +30,24 @@ def get_position(self): def execute_home_search(self): input_m = str(self.device) + "OR" - self.rotatory_stage.write(input_m) + print("power_control home search...") def get_upper_bound(self): upper_bound = "" input_m = str(self.device) + "SR" + upper_bound - upper_bound = self.rotatory_stage.query(input_m) - return upper_bound + print("power_control get upper bound...") + return 100 def get_lower_bound(self): lower_bound = "" input_m = str(self.device) + "SL" + lower_bound - lower_bound = self.rotatory_stage.query(input_m) - return lower_bound + print("power_control get lower bound...") + return 0 def move_abs(self, target_power_percent=0): target_position = self.unit_transformer(target_power_percent) input_m = str(self.device) + "PA" + str(target_position) - self.rotatory_stage.write(input_m) + print("Set power to:", target_position) def terminate_connection(self): self.rotatory_stage.close() diff --git a/twop/scanning.py b/twop/scanning.py index 89bf8e6..863f68e 100644 --- a/twop/scanning.py +++ b/twop/scanning.py @@ -60,9 +60,10 @@ def compute_waveform(sp: ScanningParameters): class Scanner(Process): - def __init__(self, experiment_start_event, duration_queue, max_queuesize=200): + def __init__(self, experiment_start_event, duration_queue, max_queuesize=200, correction=None): super().__init__() self.data_queue = ArrayQueue(max_mbytes=max_queuesize) + self.data_queue_copy = ArrayQueue(max_mbytes=max_queuesize) self.parameter_queue = Queue() self.stop_event = Event() self.experiment_start_event = experiment_start_event @@ -70,6 +71,9 @@ def __init__(self, experiment_start_event, duration_queue, max_queuesize=200): self.new_parameters = copy(self.scanning_parameters) self.duration_queue = duration_queue self.n_frames_queue = Queue() + self.correction_event = correction + self.correction_status = False + self.corrector_queue = Queue() def run(self): self.compute_scan_parameters() @@ -188,6 +192,8 @@ def scan_loop(self, read_task, write_task): print(e) break self.data_queue.put(self.read_buffer[0, :]) + if self.correction_status is True: + self.data_queue_copy.put(self.read_buffer[0, :]) # if new parameters have been received and changed, update # them, breaking out of the loop if the experiment is not running try: @@ -234,13 +240,17 @@ def run_scanning(self): and self.scanning_parameters.scanning_state == ScanningState.PAUSED ): toggle_shutter = True - + if self.correction_event.is_set(): + self.correction_status = True + else: + self.correction_status = False self.scanning_parameters = self.new_parameters self.compute_scan_parameters() with Task() as write_task, Task() as read_task, Task() as shutter_task: self.setup_tasks(read_task, write_task, shutter_task) if self.scanning_parameters.reset_shutter or toggle_shutter: self.toggle_shutter(shutter_task) + pass if self.scanning_parameters.scanning_state == ScanningState.PAUSED: self.pause_loop() else: diff --git a/twop/state.py b/twop/state.py index afb7850..c755bac 100644 --- a/twop/state.py +++ b/twop/state.py @@ -12,7 +12,7 @@ from streaming_save import StackSaver, SavingParameters, SavingStatus from arrayqueues.shared_arrays import ArrayQueue from queue import Empty -from twop.objective_motor import MotorControl +from twop.objective_motor import MotorMaster, MotorControl from twop.external_communication import ZMQcomm from twop.power_control import LaserPowerControl from math import sqrt @@ -21,7 +21,8 @@ from enum import Enum from time import sleep from sequence_diagram import SequenceDiagram - +from twop.drift_correction import * +from twop.objective_motor_sliders import MovementType class ExperimentSettings(ParametrizedQt): def __init__(self): @@ -29,9 +30,7 @@ def __init__(self): self.name = "recording" self.n_planes = Param(1, (1, 500)) self.dz = Param(1.0, (0.1, 20.0), unit="um") - self.save_dir = Param(r"C:\Users\portugueslab\Desktop\test\python", gui=False) - self.notification_email = Param("None") - self.notify_every_n_planes = Param(3, (1, 1000)) + self.save_dir = Param(r"C:\Users\epaoli\Desktop\sim_exp\results", gui=False) class ScanningSettings(ParametrizedQt): @@ -39,7 +38,7 @@ def __init__(self): super().__init__() self.name = "scanning" self.aspect_ratio = Param(1.0, (0.2, 5.0)) - self.voltage = Param(3.0, (0.2, 4.0)) + self.voltage = Param(3.0, (0.3, 4.0)) self.framerate = Param(2.0, (0.1, 10.0)) self.reset_shutter = Param(False) self.binning = Param(10, (1, 50)) @@ -109,17 +108,20 @@ def __init__(self, diagnostics=False): self.experiment_start_event = Event() self.scanning_settings = ScanningSettings() self.experiment_settings = ExperimentSettings() + self.reference_settings = ReferenceSettings() self.pause_after = False self.parameter_tree = ParameterTree() self.parameter_tree.add(self.scanning_settings) self.parameter_tree.add(self.experiment_settings) + self.parameter_tree.add(self.reference_settings) self.end_event = Event() self.external_sync = ZMQcomm() self.duration_queue = Queue() + self.correction_event = Event() self.scanner = Scanner( - self.experiment_start_event, duration_queue=self.duration_queue + self.experiment_start_event, duration_queue=self.duration_queue, correction=self.correction_event ) self.scanning_parameters = None self.reconstructor = ImageReconstructor( @@ -127,23 +129,40 @@ def __init__(self, diagnostics=False): ) self.save_queue = ArrayQueue(max_mbytes=800) + self.reference_event = Event() + self.reference_params = None + self.reference_queue = ArrayQueue(max_mbytes=800) self.saver = StackSaver( - self.scanner.stop_event, self.save_queue, self.scanner.n_frames_queue + self.scanner.stop_event, self.save_queue, self.scanner.n_frames_queue, self.reference_event, + self.reference_queue ) self.save_status: Optional[SavingStatus] = None + self.input_queues = {"x": Queue(), "y": Queue(), "z": Queue()} + self.output_queues = self.input_queues.copy() # not sure can be done with queue class + self.close_setup_event = Event() + self.move_type = MovementType(True) self.motors = dict() - self.motors["x"] = MotorControl("COM6", axes="x") - self.motors["y"] = MotorControl("COM6", axes="y") - self.motors["z"] = MotorControl("COM6", axes="z") + for axis in ["x", "y", "z"]: + self.motors[axis] = MotorControl("COM6", axis=axis) + self.master_motor = MotorMaster(self.motors, self.input_queues, + self.output_queues, self.close_setup_event) self.power_controller = LaserPowerControl() + self.corrector = Corrector(self.reference_event, self.experiment_start_event, self.scanner.stop_event, + self.correction_event, self.saver.ref_queue, self.scanner.scanning_parameters, + self.scanner.corrector_queue, self.scanner.data_queue_copy, + self.input_queues, self.output_queues, self.saver.save_parameters + ) self.scanning_settings.sig_param_changed.connect(self.send_scan_params) self.scanning_settings.sig_param_changed.connect(self.send_save_params) + self.reference_settings.sig_param_changed.connect(self.send_reference_params) + self.experiment_settings.sig_param_changed.connect(self.send_reference_params) self.scanner.start() self.reconstructor.start() + self.master_motor.start() self.saver.start() + self.corrector.start() self.open_setup() - self.paused = False @property @@ -152,13 +171,18 @@ def saving(self): def open_setup(self): self.send_scan_params() + self.send_reference_params() def start_experiment(self, first_plane=True): - duration = self.external_sync.send(self.parameter_tree.serialize()) - if duration is None: - self.restart_scanning() - return False - self.duration_queue.put(duration) + if not self.reference_event.is_set(): + duration = 8 * 60 # change + if duration is None: + self.restart_scanning() + return False + self.duration_queue.put(duration) + else: + duration = self.reference_params.n_frames_ref * (1 / self.scanning_settings.framerate) + self.duration_queue.put(duration) params_to_send = convert_params(self.scanning_settings) params_to_send.scanning_state = ScanningState.EXPERIMENT_RUNNING self.scanner.parameter_queue.put(params_to_send) @@ -193,7 +217,10 @@ def pause_scanning(self): self.paused = True def advance_plane(self): - self.motors["z"].move_rel(self.experiment_settings.dz / 1000) + if not self.reference_event.is_set(): + self.input_queues["z"].put((self.experiment_settings.dz / 1000, self.move_type)) + else: + self.input_queues["z"].put((self.reference_params.dz / 1000, self.move_type)) sleep(0.2) self.start_experiment(first_plane=False) @@ -202,9 +229,10 @@ def close_setup(self): end all parallel processes, close all communication channels """ - for motor in self.motors.values(): - motor.end_session() + # for motor in self.motors.values(): + # motor.end_session() self.power_controller.terminate_connection() + self.close_setup_event.set() self.scanner.stop_event.set() self.end_event.set() self.scanner.join() @@ -232,15 +260,32 @@ def send_scan_params(self): self.sig_scanning_changed.emit() def send_save_params(self): - self.saver.saving_parameter_queue.put( - SavingParameters( - output_dir=Path(self.experiment_settings.save_dir), - plane_size=(self.scanning_parameters.n_x, self.scanning_parameters.n_y), - n_z=self.experiment_settings.n_planes, - notification_email=self.experiment_settings.notification_email, - notification_frequency=self.experiment_settings.notify_every_n_planes + if not self.reference_event.is_set(): + self.saver.saving_parameter_queue.put( + SavingParameters( + output_dir=Path(self.experiment_settings.save_dir), + plane_size=(self.scanner.read_task.plane_size[0], self.scanner.read_task.plane_size[1]), + n_z=self.experiment_settings.n_planes, + ) ) - ) + else: + self.saver.saving_parameter_queue.put( + SavingParameters( + output_dir=Path(self.experiment_settings.save_dir), + plane_size=(self.scanner.read_task.plane_size[0], self.scanner.read_task.plane_size[1]), + n_z=self.reference_settings.dz, + ) + ) + + def send_reference_params(self): + param_to_send = convert_reference_params(self.reference_settings) + if self.reference_event.is_set(): + n_planes = (self.reference_params.extra_planes * 2) + self.experiment_settings.n_planes + else: + n_planes = self.experiment_settings.n_planes + param_to_send.n_planes = n_planes + self.reference_params = param_to_send + self.corrector.reference_param_queue.put(param_to_send) def get_save_status(self) -> Optional[SavingStatus]: try: diff --git a/twop/streaming_save.py b/twop/streaming_save.py index bb510d2..7e47aaa 100644 --- a/twop/streaming_save.py +++ b/twop/streaming_save.py @@ -31,7 +31,7 @@ class SavingStatus: class StackSaver(Process): - def __init__(self, stop_signal, data_queue, n_frames_queue): + def __init__(self, stop_signal, data_queue, n_frames_queue, ref_event, ref_queue): super().__init__() self.stop_signal = stop_signal self.data_queue = data_queue @@ -43,8 +43,11 @@ def __init__(self, stop_signal, data_queue, n_frames_queue): self.i_in_plane = 0 self.i_block = 0 self.current_data = None + self.reference = None self.saved_status_queue = Queue() self.dtype = np.float32 + self.ref_event = ref_event + self.ref_queue = ref_queue def run(self): while not self.stop_signal.is_set(): @@ -152,19 +155,20 @@ def finalize_dataset(self): ) def complete_plane(self): - fl.save( - Path(self.save_parameters.output_dir) - / "original/{:04d}.h5".format(self.i_block), - {"stack_4D": self.current_data}, - compression="blosc", - ) - self.i_block += 1 - - if self.i_block % self.save_parameters.notification_frequency == 0 and \ - self.save_parameters.notification_email != "None": - self.send_email_update(frame=self.current_data[self.i_in_plane - 1, 0, :, :]) - + if not self.ref_event.is_set(): + fl.save( + Path(self.save_parameters.output_dir) + / "original/{:04d}.h5".format(self.i_block), + {"stack_4D": self.current_data}, + compression="blosc", + ) + if self.i_block % self.save_parameters.notification_frequency == 0 and \ + self.save_parameters.notification_email != "None": + self.send_email_update(frame=self.current_data[self.i_in_plane - 1, 0, :, :]) + else: + self.fill_reference() self.i_in_plane = 0 + self.i_block += 1 def send_email_update(self, frame=None, end=False): sender_email = "fishgitbot@gmail.com" @@ -217,9 +221,23 @@ def send_email_update(self, frame=None, end=False): except OSError: pass - def receive_save_parameters(self): try: self.save_parameters = self.saving_parameter_queue.get(timeout=0.001) except Empty: pass + + def fill_reference(self): + if self.i_block == 0: + self.reference = np.zeros(( + int(self.save_parameters.n_t), + int(self.save_parameters.n_z), + self.current_data.shape[2], + self.current_data.shape[3])) + + self.reference[:, self.i_block, :, :] = self.current_data[:,0,:,:] + if self.i_block == self.save_parameters.n_z - 1: + self.send_reference() + + def send_reference(self): + self.ref_queue.put(self.reference)