# -*- coding: utf-8 -*-
"""
Created on 16 Aug 2019

@author: Thera Pals

Copyright © 2019 Thera Pals, Delmic

This file is part of Odemis.

Delmic Acquisition Software is free software: you can redistribute it and/or modify it under the terms of the GNU
General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your
option) any later version.

Delmic Acquisition Software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
more details.

You should have received a copy of the GNU General Public License along with Delmic Acquisition Software. If not, see
http://www.gnu.org/licenses/.
"""
from __future__ import division, print_function

import base64
import logging
import threading
import time
from concurrent.futures import CancelledError

import msgpack
import msgpack_numpy as m
from Pyro5.api import Proxy

from odemis import model
from odemis import util
from odemis.model import CancellableThreadPoolExecutor, HwError, isasync, CancellableFuture

# Names of the channel states, matching the strings documented on SEM.set_channel_state()
# and SEM.wait_for_state_changed().
XT_RUN = "run"
XT_STOP = "stop"
XT_CANCEL = "cancel"


class SEM(model.HwComponent):
    """
    Class to communicate with a Microscope server via a Pyro5 proxy.

    Apart from the constructor, every public method forwards to the method of the same name on the server. All these
    accesses go through _remote_call(), which serialises the use of the proxy with a lock.
    """

    def __init__(self, name, role, children, address, daemon=None,
                 **kwargs):
        """
        Parameters
        ----------
        name, role, daemon, kwargs:
            passed on to model.HwComponent.
        children: dict(str -> dict)
            Keyword arguments for the child components. "scanner" is mandatory, "stage" and "focus" are optional.
        address: str
            server address and port of the Microscope server, e.g. "PYRO:Microscope@localhost:4242"

        Raises
        ------
        HwError
            If connecting to the server, or reading its software/hardware version, fails.
        KeyError
            If children does not contain a "scanner" entry.
        """

        model.HwComponent.__init__(self, name, role, daemon=daemon, **kwargs)
        # Held during each remote call, so that only one thread at a time uses the proxy
        self._proxy_access = threading.Lock()
        try:
            self.server = Proxy(address)
            self.server._pyroTimeout = 30  # seconds
            self._swVersion = self.server.get_software_version()
            self._hwVersion = self.server.get_hardware_version()
        except Exception as err:
            raise HwError("Failed to connect to XT server '%s'. Check that the "
                          "uri is correct and XT server is"
                          " connected to the network. %s" % (address, err))

        # create the scanner child (mandatory)
        try:
            ckwargs = children["scanner"]
        except (KeyError, TypeError):
            raise KeyError("SEM was not given a 'scanner' child")
        self._scanner = Scanner(parent=self, daemon=daemon, **ckwargs)
        self.children.value.add(self._scanner)

        # create the stage child, if requested
        if "stage" in children:
            ckwargs = children["stage"]
            self._stage = Stage(parent=self, daemon=daemon, **ckwargs)
            self.children.value.add(self._stage)

        # create a focuser, if requested
        if "focus" in children:
            ckwargs = children["focus"]
            self._focus = Focus(parent=self, daemon=daemon, **ckwargs)
            self.children.value.add(self._focus)

    def _remote_call(self, method, *args):
        """
        Call a method of the server through the proxy.

        Parameters
        ----------
        method: str
            Name of the remote method.
        args:
            Positional arguments passed to the remote method.

        Returns
        -------
        Whatever the remote method returns.
        """
        with self._proxy_access:
            # A Pyro5 proxy is owned by one thread at a time: take it over for the calling thread.
            self.server._pyroClaimOwnership()
            return getattr(self.server, method)(*args)

    def list_available_channels(self):
        """
        List all available channels and their current state as a dict.

        Returns
        -------
        available channels: dict
            A dict of the names of the available channels as keys and the corresponding channel state as values.
        """
        return self._remote_call("list_available_channels")

    def move_stage(self, position, rel=False):
        """
        Move the stage the given position in meters. This is non-blocking. Throws an error when the requested position
        is out of range.

        Parameters
        ----------
        position: dict(string->float)
            Absolute or relative position to move the stage to per axis in m (e.g. 'x', 'y', 'z').
        rel: boolean
            If True the staged is moved relative to the current position of the stage, by the distance specified in
            position. If False the stage is moved to the absolute position.
        """
        self._remote_call("move_stage", position, rel)

    def stage_is_moving(self):
        """Returns: (bool) True if the stage is moving and False if the stage is not moving."""
        return self._remote_call("stage_is_moving")

    def stop_stage_movement(self):
        """Stop the movement of the stage."""
        self._remote_call("stop_stage_movement")

    def get_stage_position(self):
        """
        Returns: (dict) the axes of the stage as keys with their corresponding position.
        """
        return self._remote_call("get_stage_position")

    def stage_info(self):
        """Returns: (dict) the unit and range of the stage position."""
        return self._remote_call("stage_info")

    def acquire_image(self, channel_name):
        """
        Acquire an image observed via the given channel. Note: the channel needs to be stopped before an image
        can be acquired. To acquire multiple consecutive images the channel needs to be started and stopped. This
        causes the acquisition speed to be approximately 1 fps.

        Parameters
        ----------
        channel_name: str
            Name of the channel to acquire the image from.

        Returns
        -------
        image: numpy array
            The acquired image.
        """
        x_enc = self._remote_call("acquire_image", channel_name)
        # The 'data' entry of the reply is the msgpack-packed image, base64 encoded.
        # Decoding only needs the reply, so it is done without holding the proxy lock.
        x_dec = base64.b64decode(x_enc['data'])
        return msgpack.unpackb(x_dec, object_hook=m.decode)

    def set_scan_mode(self, mode):
        """
        Set the scan mode.
        Parameters
        ----------
        mode: str
            Name of desired scan mode, one of: unknown, external, full_frame, spot, or line.
        """
        self._remote_call("set_scan_mode", mode)

    def set_selected_area(self, start_position, size):
        """
        Specify a selected area in the scan field area.

        Parameters
        ----------
        start_position: (tuple of int)
            (x, y) of where the area starts in pixel, (0,0) is at the top left.
        size: (tuple of int)
            (width, height) of the size in pixel.
        """
        self._remote_call("set_selected_area", start_position, size)

    def get_selected_area(self):
        """
        Returns
        -------
        x, y, width, height: pixels
            The current selected area. If selected area is not active it returns the stored selected area.
        """
        # Unpack explicitly, so that a reply which doesn't have exactly 4 elements fails here
        x, y, width, height = self._remote_call("get_selected_area")
        return x, y, width, height

    def selected_area_info(self):
        """Returns: (dict) the unit and range of set selected area."""
        return self._remote_call("selected_area_info")

    def reset_selected_area(self):
        """Reset the selected area to select the entire image."""
        self._remote_call("reset_selected_area")

    def set_scanning_size(self, x):
        """
        Set the size of the to be scanned area (aka field of view or the size, which can be scanned with the current
        settings).

        Parameters
        ----------
        x: (float)
            size for X in meters.
        """
        self._remote_call("set_scanning_size", x)

    def get_scanning_size(self):
        """
        Returns: (tuple of floats) x and y scanning size in meters.
        """
        return self._remote_call("get_scanning_size")

    def scanning_size_info(self):
        """Returns: (dict) the scanning size unit and range."""
        return self._remote_call("scanning_size_info")

    def set_ebeam_spotsize(self, spotsize):
        """
        Setting the spot size of the ebeam.
        Parameters
        ----------
        spotsize: float
            desired spotsize, unitless
        """
        self._remote_call("set_ebeam_spotsize", spotsize)

    def get_ebeam_spotsize(self):
        """Returns: (float) the current spotsize of the electron beam (unitless)."""
        return self._remote_call("get_ebeam_spotsize")

    def spotsize_info(self):
        """Returns: (dict) the unit and range of the spotsize. Unit is None means the spotsize is unitless."""
        return self._remote_call("spotsize_info")

    def set_dwell_time(self, dwell_time):
        """
        Set the dwell time.

        Parameters
        ----------
        dwell_time: float
            dwell time in seconds
        """
        self._remote_call("set_dwell_time", dwell_time)

    def get_dwell_time(self):
        """Returns: (float) the dwell time in seconds."""
        return self._remote_call("get_dwell_time")

    def dwell_time_info(self):
        """Returns: (dict) range of the dwell time and corresponding unit."""
        return self._remote_call("dwell_time_info")

    def set_ht_voltage(self, voltage):
        """
        Set the high voltage.

        Parameters
        ----------
        voltage: float
            Desired high voltage value in volt.

        """
        self._remote_call("set_ht_voltage", voltage)

    def get_ht_voltage(self):
        """Returns: (float) the HT Voltage in volt."""
        return self._remote_call("get_ht_voltage")

    def ht_voltage_info(self):
        """Returns: (dict) the unit and range of the HT Voltage."""
        return self._remote_call("ht_voltage_info")

    def blank_beam(self):
        """Blank the electron beam."""
        self._remote_call("blank_beam")

    def unblank_beam(self):
        """Unblank the electron beam."""
        self._remote_call("unblank_beam")

    def beam_is_blanked(self):
        """Returns: (bool) True if the beam is blanked and False if the beam is not blanked."""
        return self._remote_call("beam_is_blanked")

    def pump(self):
        """Pump the microscope's chamber. Note that pumping takes some time. This is blocking."""
        # NOTE(review): the proxy timeout is set to 30 s in __init__; if the remote call blocks for longer than
        # that, it presumably fails with a timeout -- to be confirmed against the server implementation.
        self._remote_call("pump")

    def get_vacuum_state(self):
        """Returns: (string) the vacuum state of the microscope chamber to see if it is pumped or vented."""
        return self._remote_call("get_vacuum_state")

    def vent(self):
        """Vent the microscope's chamber. Note that venting takes time (appr. 3 minutes). This is blocking."""
        # NOTE(review): the proxy timeout is set to 30 s in __init__; if the remote call blocks for longer than
        # that, it presumably fails with a timeout -- to be confirmed against the server implementation.
        self._remote_call("vent")

    def get_pressure(self):
        """Returns: (float) the chamber pressure in pascal."""
        return self._remote_call("get_pressure")

    def home_stage(self):
        """Home stage asynchronously. This is non-blocking."""
        self._remote_call("home_stage")

    def is_homed(self):
        """Returns: (bool) True if the stage is homed and False otherwise."""
        return self._remote_call("is_homed")

    def set_channel_state(self, name, state):
        """
        Stop or start running the channel. This is non-blocking.

        Parameters
        ----------
        name: str
            name of channel.
        state: "run" or "stop"
            desired state of the channel.
        """
        self._remote_call("set_channel_state", name, state)

    def wait_for_state_changed(self, desired_state, name, timeout=10):
        """
        Wait until the state of the channel has changed to the desired state, if it has not changed after a certain
        timeout an error will be raised.

        Parameters
        ----------
        desired_state: "run", "stop" or "cancel"
            The state the channel should change into.
        name: str
            name of channel.
        timeout: int
            Amount of time in seconds to wait until the channel state has changed.
        """
        # NOTE(review): a timeout above the 30 s proxy timeout set in __init__ can presumably not be honoured.
        self._remote_call("wait_for_state_changed", desired_state, name, timeout)

    def get_channel_state(self, name):
        """Returns: (str) the state of the channel: "run", "stop" or "cancel"."""
        return self._remote_call("get_channel_state", name)

    def get_free_working_distance(self):
        """Returns: (float) the free working distance in meters."""
        return self._remote_call("get_free_working_distance")

    def set_free_working_distance(self, free_working_distance):
        """
        Set the free working distance.
        Parameters
        ----------
        free_working_distance: float
            free working distance in meters.
        """
        self._remote_call("set_free_working_distance", free_working_distance)

    def fwd_info(self):
        """Returns: (dict) the unit and range of the free working distance."""
        return self._remote_call("fwd_info")

    def get_fwd_follows_z(self):
        """
        Returns: (bool) True if Z follows free working distance.
        When Z follows FWD and Z-axis of stage moves, FWD is updated to keep image in focus.
        """
        return self._remote_call("get_fwd_follows_z")

    def set_fwd_follows_z(self, follow_z):
        """
        Set if z should follow the free working distance. When Z follows FWD and Z-axis of stage moves, FWD is updated
        to keep image in focus.
        Parameters
        ---------
        follow_z: bool
            True if Z should follow free working distance.
        """
        self._remote_call("set_fwd_follows_z", follow_z)

    def set_autofocusing(self, name, state):
        """
        Set the state of autofocus, beam must be turned on. This is non-blocking.

        Parameters
        ----------
        name: str
            Name of one of the electron channels, the channel must be running.
        state: "start", "cancel" or "stop"
            If state is start, autofocus starts. States cancel and stop both stop the autofocusing. Some microscopes
            might need stop, while others need cancel. The Apreo system requires stop.
        """
        self._remote_call("set_autofocusing", name, state)

    def is_autofocusing(self):
        """Returns: (bool) True if autofocus is running and False if autofocus is not running."""
        return self._remote_call("is_autofocusing")

    def set_auto_contrast_brightness(self, name, state):
        """
        Set the state of auto contrast brightness. This is non-blocking.

        Parameters
        ----------
        name: str
            Name of one of the electron channels.
        state: "start", "cancel" or "stop"
            If state is start, auto contrast brightness starts. States cancel and stop both stop the auto contrast
            brightness. Some microscopes might need stop, while others need cancel. The Apreo system requires stop.
        """
        self._remote_call("set_auto_contrast_brightness", name, state)

    def is_running_auto_contrast_brightness(self):
        """
        Returns: (bool) True if auto contrast brightness is running and False if auto contrast brightness is not
        running.
        """
        return self._remote_call("is_running_auto_contrast_brightness")

    def get_beam_shift(self):
        """Returns: the current beam shift x and y values in meters."""
        return self._remote_call("get_beam_shift")

    def set_beam_shift(self, x_shift, y_shift):
        """Set the current beam shift x and y values in meters."""
        self._remote_call("set_beam_shift", x_shift, y_shift)

    def beam_shift_info(self):
        """Returns: (dict) the unit and xy-range of the beam shift."""
        return self._remote_call("beam_shift_info")

    def get_stigmator(self):
        """Returns: the current stigmator x and y values."""
        return self._remote_call("get_stigmator")

    def set_stigmator(self, x, y):
        """Set the current stigmator x and y values."""
        self._remote_call("set_stigmator", x, y)

    def stigmator_info(self):
        """Returns: (dict) the unit and xy-range of the stigmator."""
        return self._remote_call("stigmator_info")

    def get_rotation(self):
        """Returns: (float) the current rotation value in rad."""
        return self._remote_call("get_rotation")

    def set_rotation(self, rotation):
        """Set the current rotation value in rad."""
        self._remote_call("set_rotation", rotation)

    def rotation_info(self):
        """Returns: (dict) the unit and range of the rotation."""
        return self._remote_call("rotation_info")


class Scanner(model.Emitter):
    """
    This is an extension of the model.Emitter class. It contains Vigilant
    Attributes for magnification, accel voltage, blanking, spotsize, beam shift,
    rotation, horizontal field of view and dwell time. Whenever one of these
    attributes is changed, its setter also updates another value if needed.
    The values are also re-read from the hardware every 5 s.
    """

    def __init__(self, name, role, parent, hfw_nomag, **kwargs):
        """
        name, role, kwargs: passed on to model.Emitter
        parent (SEM): the component used to talk to the microscope server
        hfw_nomag (float): the magnification is computed as hfw_nomag / horizontalFoV (so it corresponds to the
          horizontal field width at a magnification of 1)
        """
        model.Emitter.__init__(self, name, role, parent=parent, **kwargs)
        self._hfw_nomag = hfw_nomag

        # For each setting, the range and unit are provided by the server
        dwell_time_info = self.parent.dwell_time_info()
        self.dwellTime = model.FloatContinuous(
            self.parent.get_dwell_time(),
            dwell_time_info["range"],
            unit=dwell_time_info["unit"],
            setter=self._setDwellTime)

        voltage_info = self.parent.ht_voltage_info()
        self.accelVoltage = model.FloatContinuous(
            self.parent.get_ht_voltage(),
            voltage_info["range"],
            unit=voltage_info["unit"],
            setter=self._setVoltage
        )

        self.blanker = model.BooleanVA(
            self.parent.beam_is_blanked(),
            setter=self._setBlanker)

        spotsize_info = self.parent.spotsize_info()
        self.spotSize = model.FloatContinuous(
            self.parent.get_ebeam_spotsize(),
            spotsize_info["range"],
            unit=spotsize_info["unit"],
            setter=self._setSpotSize)

        beam_shift_info = self.parent.beam_shift_info()
        range_x = beam_shift_info["range"]["x"]
        range_y = beam_shift_info["range"]["y"]
        # The server gives (min, max) per axis, while the VA needs ((min x, min y), (max x, max y))
        self.beamShift = model.TupleContinuous(
            self.parent.get_beam_shift(),
            ((range_x[0], range_y[0]), (range_x[1], range_y[1])),
            cls=(int, float),
            unit=beam_shift_info["unit"],
            setter=self._setBeamShift)

        rotation_info = self.parent.rotation_info()
        self.rotation = model.FloatContinuous(
            self.parent.get_rotation(),
            rotation_info["range"],
            unit=rotation_info["unit"],
            setter=self._setRotation)

        scanning_size_info = self.parent.scanning_size_info()
        fov = self.parent.get_scanning_size()[0]
        self.horizontalFoV = model.FloatContinuous(
            fov,
            unit=scanning_size_info["unit"],
            range=scanning_size_info["range"]["x"],
            setter=self._setHorizontalFoV)

        # The magnification is inversely proportional to the FoV: the smallest FoV gives the largest magnification
        mag = self._hfw_nomag / fov
        mag_range_max = self._hfw_nomag / scanning_size_info["range"]["x"][0]
        mag_range_min = self._hfw_nomag / scanning_size_info["range"]["x"][1]
        self.magnification = model.FloatContinuous(mag, unit="",
                                                   range=(mag_range_min, mag_range_max),
                                                   readonly=True)
        # To provide some rough idea of the step size when changing focus
        # Depends on the pixelSize, so will be updated whenever the HFW changes
        self.depthOfField = model.FloatContinuous(1e-6, range=(0, 1e3),
                                                  unit="m", readonly=True)
        self._updateDepthOfField(fov)

        # Refresh regularly the values, from the hardware, starting from now
        self._updateSettings()
        self._va_poll = util.RepeatingTimer(5, self._updateSettings, "Settings polling")
        self._va_poll.start()

    @staticmethod
    def _refreshVA(va, value):
        """
        Store a value read from the hardware in a VA, and notify the subscribers, if it differs from the current one.
        The value is written directly, bypassing the setter, so that it is not sent back to the hardware.
        va (VigilantAttribute): the VA to update
        value: the value read from the hardware
        """
        if value != va.value:
            va._value = value
            va.notify(value)

    def _updateSettings(self):
        """
        Read all the current settings from the SEM and reflects them on the VAs
        """
        logging.debug("Updating SEM settings")
        try:
            # Each setting is read and immediately reflected, so that if reading one setting fails, the previous
            # ones have already been updated.
            polled = (
                (self.dwellTime, self.parent.get_dwell_time),
                (self.accelVoltage, self.parent.get_ht_voltage),
                (self.blanker, self.parent.beam_is_blanked),
                (self.spotSize, self.parent.get_ebeam_spotsize),
                (self.beamShift, self.parent.get_beam_shift),
                (self.rotation, self.parent.get_rotation),
            )
            for va, getter in polled:
                self._refreshVA(va, getter())

            # The FoV has VAs depending on it: update all the values first, and only then notify
            fov = self.parent.get_scanning_size()[0]
            if fov != self.horizontalFoV.value:
                self.horizontalFoV._value = fov
                mag = self._hfw_nomag / fov
                self.magnification._value = mag
                self.horizontalFoV.notify(fov)
                self.magnification.notify(mag)
                self._updateDepthOfField(fov)
        except Exception:
            logging.exception("Unexpected failure when polling settings")

    # The setters send the requested value to the hardware, and return the value read back from it, which is the
    # one eventually stored in the VA.

    def _setDwellTime(self, dwell_time):
        self.parent.set_dwell_time(dwell_time)
        return self.parent.get_dwell_time()

    def _setVoltage(self, voltage):
        self.parent.set_ht_voltage(voltage)
        return self.parent.get_ht_voltage()

    def _setBlanker(self, blank):
        """True if the the electron beam should blank, False if it should be unblanked."""
        if blank:
            self.parent.blank_beam()
        else:
            self.parent.unblank_beam()
        return self.parent.beam_is_blanked()

    def _setSpotSize(self, spotsize):
        self.parent.set_ebeam_spotsize(spotsize)
        return self.parent.get_ebeam_spotsize()

    def _setBeamShift(self, beam_shift):
        self.parent.set_beam_shift(*beam_shift)
        return self.parent.get_beam_shift()

    def _setRotation(self, rotation):
        self.parent.set_rotation(rotation)
        return self.parent.get_rotation()

    def _setHorizontalFoV(self, fov):
        """
        Set the horizontal field of view, and update the VAs derived from it (magnification and depth of field).
        fov (float): the requested field of view
        return (float): the actual field of view, as read back from the hardware
        """
        self.parent.set_scanning_size(fov)
        fov = self.parent.get_scanning_size()[0]
        mag = self._hfw_nomag / fov
        self.magnification._value = mag
        self.magnification.notify(mag)
        # Pass the new FoV explicitly: while the setter is running, horizontalFoV.value is still the previous value.
        self._updateDepthOfField(fov)
        return fov

    def _updateDepthOfField(self, fov=None):
        """
        Update the depthOfField VA, based on the horizontal field of view.
        fov (float or None): the horizontal field of view to use. If None, the current value of horizontalFoV is
          used.
        """
        if fov is None:
            fov = self.horizontalFoV.value
        # Formula was determined by experimentation
        K = 100  # Magical constant that gives a not too bad depth of field
        dof = K * (fov / 1024)
        self.depthOfField._set_value(dof, force_write=True)


class Stage(model.Actuator):
    """
    This is an extension of the model.Actuator class. It provides functions for
    moving the TFS stage and updating the position.
    """

    def __init__(self, name, role, parent, rng=None, **kwargs):
        """
        name, role, kwargs: passed on to model.Actuator
        parent (SEM): the component used to talk to the microscope server
        rng (None or dict str -> (float, float)): range of the axes "x", "y" and "z". For every axis which is not
          specified, the range reported by the server is used.
        """
        # Work on a copy, so that the dict of the caller is not modified
        rng = dict(rng) if rng is not None else {}
        stage_info = parent.stage_info()
        for axis in ("x", "y", "z"):
            if axis not in rng:
                rng[axis] = stage_info["range"][axis]

        axes_def = {
            # Ranges are either given by the caller, or reported by the server
            "x": model.Axis(unit="m", range=rng["x"]),
            "y": model.Axis(unit="m", range=rng["y"]),
            "z": model.Axis(unit="m", range=rng["z"]),
        }

        model.Actuator.__init__(self, name, role, parent=parent, axes=axes_def,
                                **kwargs)
        # will take care of executing axis move asynchronously
        self._executor = CancellableThreadPoolExecutor(max_workers=1)  # one task at a time

        self.position = model.VigilantAttribute({}, unit=stage_info["unit"],
                                                readonly=True)
        self._updatePosition()

        # Refresh regularly the position
        self._pos_poll = util.RepeatingTimer(5, self._refreshPosition, "Position polling")
        self._pos_poll.start()

    def _updatePosition(self, raw_pos=None):
        """
        update the position VA
        raw_pos (None or dict str -> float): the position as received from the SEM. It is used as-is, apart from the
          axes inversion, so it must be in the same unit as the position VA. If None, the current position is read
          from the SEM.
        """
        if raw_pos is None:
            raw_pos = self.parent.get_stage_position()

        # Only keep the axes that this component provides
        pos = {"x": raw_pos["x"],
               "y": raw_pos["y"],
               "z": raw_pos["z"],
               }
        self.position._set_value(self._applyInversion(pos), force_write=True)

    def _refreshPosition(self):
        """
        Called regularly to update the current position
        """
        # We don't use the VA setters, to avoid sending back to the hardware a
        # set request
        logging.debug("Updating SEM stage position")
        try:
            self._updatePosition()
        except Exception:
            logging.exception("Unexpected failure when updating position")

    def _moveTo(self, future, pos, timeout=60):
        """
        Move the stage to an absolute position, and block until the stage has stopped moving.
        future (CancellableFuture): the future handling the move, as created by _createFuture()
        pos (dict str -> float): position to reach, as passed to the SEM (so, after axes inversion)
        timeout (float): maximum duration of the move in s. After this duration the stage is requested to stop.
        raise CancelledError: if the move was cancelled
        """
        with future._moving_lock:
            try:
                if future._must_stop.is_set():
                    raise CancelledError()
                logging.debug("Moving to position {}".format(pos))
                self.parent.move_stage(pos, rel=False)
                time.sleep(0.5)

                # Wait until the move is over.
                # Don't check for future._must_stop because anyway the stage will
                # stop moving, and so it's nice to wait until we know the stage is
                # not moving.
                moving = True
                tstart = time.time()
                while moving:
                    cur_pos = self.parent.get_stage_position()
                    moving = self.parent.stage_is_moving()
                    # Take the opportunity to update .position
                    self._updatePosition(cur_pos)

                    if time.time() > tstart + timeout:
                        # NOTE(review): the move is only logged as failed, so the future ends without error
                        # although the target was possibly not reached -- confirm this is intended.
                        self.parent.stop_stage_movement()
                        logging.error("Timeout after submitting stage move. Aborting move.")
                        break

                    # Wait for 50ms so that we do not keep using the CPU all the time.
                    time.sleep(50e-3)

                # If it was cancelled, Abort() has stopped the stage before, and
                # we still have waited until the stage stopped moving. Now let
                # know the user that the move is not complete.
                if future._must_stop.is_set():
                    raise CancelledError()
            except Exception:
                # An error during a cancelled move is reported as a cancellation
                if future._must_stop.is_set():
                    raise CancelledError()
                raise
            finally:
                future._was_stopped = True
                # Update the position, even if the move didn't entirely succeed.
                # A failure here is only logged, so that it doesn't replace the exception
                # (if any) which describes the actual outcome of the move.
                try:
                    self._updatePosition()
                except Exception:
                    logging.exception("Unexpected failure when updating position")

    def _doMoveRel(self, future, shift):
        """
        Run a relative move, by converting it to an absolute move from the current position. Blocking.
        future (CancellableFuture): the future handling the move
        shift (dict str -> float): shift per axis, as prepared by moveRel() (so, after axes inversion)
        raise ValueError: if the target position is out of the range of one of the axes moved
        """
        pos = self.parent.get_stage_position()
        for k, v in shift.items():
            pos[k] += v

        # The range of the axes is compared to the position after applying the axes inversion
        target_pos = self._applyInversion(pos)
        # Check range (for the axes we are moving)
        for an in shift.keys():
            rng = self.axes[an].range
            p = target_pos[an]
            if not rng[0] <= p <= rng[1]:
                raise ValueError("Relative move would cause axis %s out of bound (%g m)" % (an, p))

        self._moveTo(future, pos)

    @isasync
    def moveRel(self, shift):
        """
        Shift the stage the given position in meters. This is non-blocking.
        Throws an error when the requested position is out of range.

        Parameters
        ----------
        shift: dict(string->float)
            Relative shift to move the stage to per axes in m. Axes are 'x', 'y' and 'z'.

        Returns
        -------
        A future, which ends when the move is complete.
        """
        if not shift:
            return model.InstantaneousFuture()
        self._checkMoveRel(shift)
        shift = self._applyInversion(shift)

        f = self._createFuture()
        f = self._executor.submitf(f, self._doMoveRel, f, shift)
        return f

    def _doMoveAbs(self, future, pos):
        """
        Run an absolute move. Blocking.
        future (CancellableFuture): the future handling the move
        pos (dict str -> float): position per axis, as prepared by moveAbs() (so, after axes inversion)
        """
        self._moveTo(future, pos)

    @isasync
    def moveAbs(self, pos):
        """
        Move the stage the given position in meters. This is non-blocking.
        Throws an error when the requested position is out of range.

        Parameters
        ----------
        pos: dict(string->float)
            Absolute position to move the stage to per axes in m. Axes are 'x', 'y' and 'z'.

        Returns
        -------
        A future, which ends when the move is complete.
        """
        if not pos:
            return model.InstantaneousFuture()
        self._checkMoveAbs(pos)
        pos = self._applyInversion(pos)

        f = self._createFuture()
        f = self._executor.submitf(f, self._doMoveAbs, f, pos)
        return f

    def stop(self, axes=None):
        """
        Stop the movement of the stage, and cancel the moves queued.
        axes: unused, the whole stage is always stopped
        """
        self._executor.cancel()
        self.parent.stop_stage_movement()
        try:
            self._updatePosition()
        except Exception:
            logging.exception("Unexpected failure when updating position")

    def _createFuture(self):
        """
        Return (CancellableFuture): a future that can be used to manage a move
        """
        f = CancellableFuture()
        f._moving_lock = threading.Lock()  # taken while moving
        f._must_stop = threading.Event()  # cancel of the current future requested
        f._was_stopped = False  # if cancel was successful
        f.task_canceller = self._cancelCurrentMove
        return f

    def _cancelCurrentMove(self, future):
        """
        Cancels the current move (both absolute or relative). It requests the stage
        to stop, and then waits until the move function has returned.
        future (Future): the future to stop. Only one future must be running
         at a time.
        return (bool): True if it successfully cancelled (stopped) the move.
        """
        # The difficulty is to synchronise correctly when:
        #  * the task is just starting (not finished requesting axes to move)
        #  * the task is finishing (about to say that it finished successfully)
        logging.debug("Cancelling current move")
        future._must_stop.set()  # tell the thread taking care of the move it's over
        self.parent.stop_stage_movement()

        # The lock is held by _moveTo() for the whole move, so this waits until the move has ended
        with future._moving_lock:
            if not future._was_stopped:
                logging.debug("Cancelling failed")
            return future._was_stopped


class Focus(model.Actuator):
    """
    This is an extension of the model.Actuator class. It provides functions for
    moving the SEM focus (as it's considered an axis in Odemis). It has a
    single axis "z", whose position is the free working distance read from
    (and written to) the parent component.
    """

    def __init__(self, name, role, parent, **kwargs):
        """
        name (str): name of the component, passed to model.Actuator
        role (str): role of the component, passed to model.Actuator
        parent: component providing fwd_info(), get_free_working_distance()
          and set_free_working_distance()
        kwargs: any other argument is passed as-is to model.Actuator
        """
        # The unit and range of the "z" axis are the ones reported by the parent
        fwd_info = parent.fwd_info()
        axes_def = {
            "z": model.Axis(unit=fwd_info["unit"], range=fwd_info["range"]),
        }

        model.Actuator.__init__(self, name, role, parent=parent, axes=axes_def, **kwargs)

        # will take care of executing axis move asynchronously
        self._executor = CancellableThreadPoolExecutor(max_workers=1)  # one task at a time

        # RO, as to modify it the server must use .moveRel() or .moveAbs()
        self.position = model.VigilantAttribute({}, unit="m", readonly=True)
        self._updatePosition()

        # Refresh regularly the position (period of 5, presumably in seconds)
        self._pos_poll = util.RepeatingTimer(5, self._refreshPosition, "Position polling")
        self._pos_poll.start()

    def _updatePosition(self):
        """
        Read the free working distance from the parent and update the
        position VA with it.
        """
        z = self.parent.get_free_working_distance()
        # Write directly the value, as the VA is read-only
        self.position._set_value({"z": z}, force_write=True)

    def _refreshPosition(self):
        """
        Called regularly to update the current position. Any failure is
        logged and not propagated to the caller.
        """
        # We don't use the VA setters, to avoid sending back to the hardware a
        # set request
        logging.debug("Updating SEM focus position")
        try:
            self._updatePosition()
        except Exception:
            logging.exception("Unexpected failure when updating position")

    def _doMoveRel(self, foc):
        """
        move by foc
        foc (float): relative change in m
        """
        try:
            # The parent only accepts an absolute value, so compute the target
            # from the current free working distance.
            target = self.parent.get_free_working_distance() + foc
            self.parent.set_free_working_distance(target)
        finally:
            # Update the position, even if the move didn't entirely succeed
            self._updatePosition()

    def _doMoveAbs(self, foc):
        """
        move to pos
        foc (float): unit m
        """
        try:
            self.parent.set_free_working_distance(foc)
        finally:
            # Update the position, even if the move didn't entirely succeed
            self._updatePosition()

    @isasync
    def moveRel(self, shift):
        """
        Request a relative move of the focus.
        shift (dict): shift in m, for the axis "z"
        return (Future): future of the move, as returned by the executor. If
          shift is empty, an InstantaneousFuture is returned and nothing is done.
        """
        if not shift:
            return model.InstantaneousFuture()
        self._checkMoveRel(shift)

        foc = shift["z"]
        f = self._executor.submit(self._doMoveRel, foc)
        return f

    @isasync
    def moveAbs(self, pos):
        """
        Request an absolute move of the focus.
        pos (dict): pos in m, for the axis "z"
        return (Future): future of the move, as returned by the executor. If
          pos is empty, an InstantaneousFuture is returned and nothing is done.
        """
        if not pos:
            return model.InstantaneousFuture()
        self._checkMoveAbs(pos)

        foc = pos["z"]
        f = self._executor.submit(self._doMoveAbs, foc)
        return f

    def stop(self, axes=None):
        """
        Stop the last command
        axes: not used by this implementation, the same actions are performed
          whatever its value.
        """
        # Cancel the focus moves handled by the executor
        self._executor.cancel()
        logging.debug("Cancelled all ebeam focus moves")

        try:
            # Best effort: a failure to read back the position is only logged
            self._updatePosition()
        except Exception:
            logging.exception("Unexpected failure when updating position")
