import astropy.io.fits as pyfits
import binascii
import cv2
import datetime
import glob
import numpy as np
#import pyds9
import serial
import struct
import threading
import time
from types import MethodType

# FLIR serial-protocol status codes, stored as flat (byte, name, description) triplets
_status_codes = [b'\x00', 'CAM_OK', 'Function executed',
                 b'\x01', 'CAM_BUSY', 'Camera busy processing serial command',
                 b'\x02', 'CAM_NOT_READY', 'Camera not ready to execute specified serial command',
                 b'\x03', 'CAM_RANGE_ERROR', 'Data out of range',
                 b'\x04', 'CAM_CHECKSUM_ERROR', 'Header or message-body checksum error',
                 b'\x05', 'CAM_UNDEFINED_PROCESS_ERROR', 'Unknown process code',
                 b'\x06', 'CAM_UNDEFINED_FUNCTION_ERROR', 'Unknown function code',
                 b'\x07', 'CAM_TIMEOUT_ERROR', 'Timeout executing serial command',
                 b'\x09', 'CAM_BYTE_COUNT_ERROR', 'Byte count incorrect for the function code',
                 b'\x0A', 'CAM_FEATURE_NOT_ENABLED', 'Function code not enabled in the current configuration']
status_codes = dict(zip(_status_codes[1::3], _status_codes[0::3]))  # name -> status byte
status_map = dict(zip(_status_codes[0::3], _status_codes[1::3]))    # status byte -> name
status_doc = dict(zip(_status_codes[1::3], _status_codes[2::3]))    # name -> description

def bitconv(B):
    """Unpack packed pixel data: each group of 7 bytes holds four 14-bit values.

    B is expected to be a uint8 array whose length is a multiple of 14.
    """
    B0 = (B[0::7] << 6) + (B[1::7] >> 2)
    B1 = ((B[1::7] & 0x03) << 12) + (B[2::7] << 4) + (B[3::7] >> 4)
    B2 = ((B[3::7] & 0x0F) << 10) + (B[4::7] << 2) + (B[5::7] >> 6)
    B3 = ((B[5::7] & 0x3F) << 8) + B[6::7]
    BP = np.zeros(len(B) // 14 * 8, dtype=int)  # four 14-bit words per 7 input bytes
    BP[::4] = B0
    BP[1::4] = B1
    BP[2::4] = B2
    BP[3::4] = B3
    return BP

class Flir(object):
    def __init__(self, port='/dev/ttyUSB0'):
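        """Open the serial link and attach one method per entry of the command table."""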
        self.serial = serial.Serial(port=port, baudrate=57600)
        #self._commands = NTuple.fromorg('command_byte.org')
        self._commands = np.load('command_byte.npy')
        for c in self._commands:
            setattr(self, c['Command'].lower(), self._command_factory(c))
        self.sensorid = {'INSTRUM': 'STARDICE_IRT0',
                         'MANUFACT': 'FLIR',
                         'CAMPART': b''.join(self.camera_part()).replace(b'\x00', b'').decode(),
                         'SERNUM': self.serial_number()[0],
                         'SENSNUM': self.serial_number()[1],
                         'FIRMWARE': '%d.%d'%(self.get_revision()[:2]),
                         'SOFTWARE': '%d.%d'%(self.get_revision()[2:])}
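        # Resume the image odometer from any IR*.fits files already on disk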
        l = glob.glob('IR*.fits')
        l.sort()
        if l:
            self.odo = int(l[-1][2:9]) + 1
        else:
            self.odo = 0

    def _command_factory(self, c):
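        """Build a bound method from one row of the command table.

        The generated method packs its arguments according to the table entry,
        sends the command, checks the returned status and decodes the reply.
        """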
        if c['ENUM'] != 'nan':
            enum = [e.strip() for e in c['ENUM'].split(',')]
        code = int(c['Code'], 0)
        def f(self, *args, full=False):
            if c['BCC']:
                if c['ENUM'] != 'nan':
                    data = struct.pack(c['ArgumentC'], enum.index(args[0]))
                elif c['RANGE'] != 'nan':
                    r = [int(_r) for _r in c['RANGE'].split('-')]
                    if args[0] > r[1] or args[0] < r[0]:
                        raise ValueError(f'Argument should be in range {c["RANGE"]}')
                    data = struct.pack(c['ArgumentC'], args[0])
                else:
                    data = struct.pack(c['ArgumentC'], *args)
            else:
                data = b''
            self.write(code, data)
            answer = self.read()
            if answer['status'] != 'CAM_OK':
                raise RuntimeError(f'{answer["status"]}: {status_doc[answer["status"]]}')
            if c['ArgumentR'] != 'nan':
                answer['data'] = struct.unpack(c['ArgumentR'], answer['data'])
            if c['ENUM'] != 'nan':
                try:
                    answer['data'] = enum[answer['data'][0]]
                except IndexError:
                    pass
            if full:
                return answer
            else:
                if len(answer['data']) == 1:
                    return answer['data'][0]
                else:
                    return answer['data']
        doc = c['Description']
        if c['ENUM'] != 'nan':
            doc+= f'\n{c["ENUM"]}'
        if c['Notes'] != 'nan':
            doc+= f'\n{c["Notes"]}'
        f.__doc__= doc
        return MethodType(f, self)
    
    def write(self, function, data):
        """Send a command frame: 6-byte header, header CRC, data, full-message CRC.

        Both checksums are CRC-CCITT as computed by binascii.crc_hqx.
        """
        s = struct.Struct('>cccBH')
        process_code = b'\x6E'
        status = status_codes['CAM_OK']
        reserved = b'\x00'
        byte_count = len(data)
        header = s.pack(process_code, status, reserved, function, byte_count)
        crc1 = struct.pack('>H', binascii.crc_hqx(header, 0))
        crc2 = struct.pack('>H', binascii.crc_hqx(header + crc1 + data, 0))
        msg = header + crc1 + data + crc2
        self.serial.write(msg)
        return msg

    def read(self):
        """Read the camera reply: fixed 8-byte header, then payload and trailing CRC."""
        msg = self.serial.read(8)
        answer = dict(zip(['process_code', 'status', 'reserved', 'function', 'byte_count', 'crc1'],
                          struct.unpack('>ccccHH', msg)))
        msg = self.serial.read(answer['byte_count'] + 2)
        answer['data'] = msg[:-2]
        answer['crc2'] = struct.unpack('>H', msg[-2:])[0]
        answer['status'] = status_map[answer['status']]
        return answer

    def video_display(self, video_device=2):
        """Grab frames from the analog video output and display them live in DS9."""
        import pyds9  # deferred import: only needed when a DS9 display is requested
        d = pyds9.DS9()
        video = cv2.VideoCapture(video_device)
        self._video_on = True
        def _run():
            while self._video_on:
                success, self._image = video.read()
                d.set_np2arr(self._image[:,:,0])
        self._video_thread = threading.Thread(target=_run)
        self._video_thread.start()
        
    def video_stop(self):
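        """Stop the background video-display thread started by video_display()."""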
        self._video_on = False
        self._video_thread.join()

    def capture(self):
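        """Freeze the video and write the current frame to a FITS file with telemetry keywords.

        Requires video_display() to be running so that self._image is populated.
        """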
        self.set_video_mode('FREEZE')
        keys = {'SPOT': self.get_spot_meter_data(),
                'SENSORC': self.read_sensor(0) /10.,
                'SENSORR': self.read_sensor(1),
                'UTCDATE': datetime.datetime.utcnow().isoformat(),
                'BRIGHT': self.get_brightness(),
                'CONTRAST': self.get_contrast()}
        keys.update(self.sensorid)
        imfname = 'IR%07d.fits' % self.odo
        pyfits.writeto(imfname, self._image[:,:,0], header=pyfits.Header(keys))
        print(f'image written: {imfname}')
        self.odo+=1
        self.set_video_mode('REALTIME')

    def read_memory(self, address, length):
        """Read `length` bytes of camera memory starting at `address` and return the raw payload."""
        self.write(0xD2, struct.pack('>IH', address, length))
        return self.read()['data']

    def read_snapshot(self, snap):
        """Download snapshot `snap` from camera memory in 256-byte chunks."""
        address, nbytes = self.get_nuc_address(snap, 0x13)
        n256 = nbytes // 256
        rem = nbytes % 256
        res = b''
        for i in range(n256):
            print(f'{i}/{n256}')
            res += self.read_memory(address + i * 256, 256)
        if rem:
            res += self.read_memory(address + n256 * 256, rem)
        return np.array(struct.unpack(f'>{nbytes}B', res))

    def auto_adjust(self, contrast=64):
        """Center the brightness on the spot-meter reading and set a fixed contrast."""
        self.set_brightness(self.get_spot_meter_data())
        self.set_contrast(contrast)  # contrast 64 corresponds to unit gain (see contrast_to_gain)


def contrast_to_gain(contrast):
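    """Convert the camera contrast setting to an effective gain (contrast 64 = unit gain)."""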
    return contrast/64

def videotodigital(image, levelvalue=140):
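    """Convert a captured frame (FITS filename or opened HDU list) back to digital counts.

    Uses the CONTRAST and BRIGHT header keywords and crops the frame borders.
    """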
    if isinstance(image, str):
        image = pyfits.open(image)
    contrast = image[0].header['CONTRAST']
    brightness = image[0].header['BRIGHT']
    alpha = contrast_to_gain(contrast)
    beta = brightness - levelvalue / alpha
    im64 = image[0].data / alpha + beta
    return im64[19:-17, 39:-35]
    
def test_contrast():
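    """Step through contrast settings and record the image mean/std before and after a
    brightness offset. Uses the module-level Flir instance f created in __main__.
    """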
    res = []
    for contrast in range(0, 255):
        delta_b = int(30 * 100/(contrast+1))
        f.set_contrast(contrast)
        b = f.get_spot_meter_data()
        f.set_brightness(b)
        time.sleep(0.5)
        i1 = f._image[100:300,100:300]
        f.set_brightness(b+delta_b)
        time.sleep(0.5)
        i2 = f._image[100:300,100:300]
        res.append([contrast, b, delta_b, i1.mean(), i1.std(), i2.mean(), i2.std()])
    return np.rec.fromrecords(res, names=['contrast', 'brightness', 'delta_b', 'm1', 's1', 'm2', 's2'])


if __name__ == '__main__':
    f = Flir()
    
    f.set_video_lut('WHITEHOT')
    f.set_video_color_mode('MONOCHROME')
    f.set_video_orientation('INVERT')
    f.set_agc_type('MANUAL')
    f.auto_adjust()
    #f.video_display()
    
    #for i in range(360):
    #    time.sleep(10)
    #    f.auto_adjust()
    #    f.capture()
    
    
