import numpy as np
import matplotlib.pyplot as plt
import time
import pyvisa as visa
import threading
import xmlrpc
from xmlrpc.client import ServerProxy
import cantrips as can
import sys
import datetime
import astropy.io.fits as fits
from astropy.table import Table
import os
from calculate_npulses_PDSM05PD3A_NewPD import calculate_npulses


# Seconds the Keithley is held in zero-check by before_exposure(); the same
# value is recorded as 'darktime' in the README and in every FITS header.
KEITHLEY_DISCHARGE_TIME = 2

def keithley_integration_loop(kei=None, nreads=1):
    """Poll the Keithley until its ``_integrate`` flag is cleared.

    Meant to run on a worker thread (see non_blocking_integration_loop).
    ``kei.data`` is reset to an empty list, then every result of
    ``kei.getread(nreads)`` is appended in acquisition order. The flag is
    checked before each read, so a flag that is already False yields no reads.
    """
    kei.data = []
    while kei._integrate:
        reading = kei.getread(nreads)
        kei.data.append(reading)

def spectro_integration_loop(spectro, spectro_exptime=1):
    """Acquire spectra back to back until ``spectro._integrate`` is cleared.

    ``spectro.spectra`` is reset to an empty list, then each call to
    ``spectro.get_spectrograph`` is appended. The exposure passed to it is
    ``spectro_exptime * 1e6`` truncated to an int (presumably seconds ->
    microseconds; confirm against the spectrograph server).
    """
    spectro.spectra = []
    while spectro._integrate:
        exposure_us = int(spectro_exptime * 1e6)
        spectro.spectra.append(spectro.get_spectrograph(exposure_us))

def non_blocking_integration_loop(kei, spectro=None, nreads=1, spectro_exptime=1):
    """Start background acquisition threads for the Keithley and, optionally, the spectrograph.

    The Keithley thread runs keithley_integration_loop. The spectrograph
    thread runs spectro_integration_loop and is created only when *spectro*
    is not None. Each thread is stored on its device object as ``.p`` and
    keeps running until the device's ``_integrate`` flag is cleared
    (keithley_join / spectro_join do this and then join ``.p``).

    Args:
        kei: Keithley proxy. ``p`` and ``_integrate`` are set on it, and the
            loop fills ``kei.data``.
        spectro: Spectrograph proxy, or None to acquire with the Keithley only.
        nreads: Passed to ``kei.getread`` on every poll.
        spectro_exptime: Spectrograph exposure. spectro_integration_loop
            multiplies it by 1e6 and truncates it to an int.
    """
    kei.p = threading.Thread(target=keithley_integration_loop,
                             kwargs={'kei': kei, 'nreads': nreads})
    # The flag must be raised before start(), otherwise the loop exits at once.
    kei._integrate = True
    kei.p.start()
    print ('spectro = ' + str(spectro))
    if spectro is not None:
        # The keyword must match spectro_integration_loop's parameter name
        # ('spectro_exptime'). The thread also has to be started with its flag
        # raised so that spectro_join can stop and join it.
        spectro.p = threading.Thread(target=spectro_integration_loop,
                                     kwargs={'spectro': spectro,
                                             'spectro_exptime': spectro_exptime})
        spectro._integrate = True
        spectro.p.start()

def keithley_join(kei, spectro=None):
    """Stop the Keithley polling thread and return its readings as a record array.

    Clears ``kei._integrate`` and waits for the thread stored in ``kei.p``.
    It then transposes ``kei.data`` and uses the first two entries along the
    leading axis of the transposed array as the ``charges`` and ``time``
    fields. *spectro* is accepted but unused.
    """
    kei._integrate = False
    kei.p.join()
    columns = np.array(kei.data).T
    first_two = columns[[0, 1]]
    return np.rec.fromarrays(first_two, names=['charges', 'time'])
    
def spectro_join(spectro):
    """Stop the spectrograph acquisition thread and pack its spectra into a record array.

    Clears ``spectro._integrate``, joins ``spectro.p`` and returns a
    ``np.recarray`` with these fields:
    - ``wl``: the wavelength axis, taken from the first spectrum.
    - ``flux<n>``: one field per acquired spectrum.

    Each entry of ``spectro.spectra`` is assumed to be a
    (wavelengths, fluxes) pair (TODO confirm against get_spectrograph).

    Raises:
        RuntimeError: if no spectrum was acquired before the join.
    """
    spectro._integrate = False
    spectro.p.join()
    spectra = spectro.spectra
    if not spectra:
        raise RuntimeError('spectro_join: no spectrum was acquired')
    # All spectra are assumed to share the wavelength axis of the first one.
    arrays = [spectra[0][0]] + [s[1] for s in spectra]
    names = ['wl'] + [f'flux{n}' for n in range(len(spectra))]
    return np.rec.fromarrays(arrays, names=names)

def before_exposure(kei, spectro=None):
    """Prepare the Keithley for an exposure.

    The sequence is:
    1. Select charge mode (mode code 1).
    2. Enable zero-check.
    3. Wait KEITHLEY_DISCHARGE_TIME seconds.
    4. Disable zero-check.

    Presumably this bleeds off accumulated charge before the next
    integration (confirm with the Keithley server). *spectro* is accepted
    but unused.
    """
    charge_mode_code = 1
    kei.set_charge_mode(charge_mode_code)
    kei.zero_check(True)
    time.sleep(KEITHLEY_DISCHARGE_TIME)
    kei.zero_check(False)

# Laser filter-wheel geometry used by LaserWheel: slot n sits at angle
# phase_filtre + n * delta_filtre, in the units reported by the wheel server.
delta_filtre = -45
phase_filtre = 100
number_filtre = 4
# Filter names, indexed by slot.
POS_0, POS_1, POS_2, POS_3 = 'EMPTY', 'red532', 'blue680', '1064'
l_POS = [POS_0, POS_1, POS_2, POS_3]


class LaserWheel():
    """Laser filter wheel driven through an XML-RPC proxy.

    Slots are numbered 0..number_filtre-1 and named from ``l_POS``. Slot
    ``n`` sits at angle ``phase_filtre + n * delta_filtre``, in the units
    returned by the proxy's ``get_position()`` (presumably degrees; confirm).

    Args:
        laserwheel_proxy: Object exposing ``get_position()`` and
            ``set_position(angle)``. When omitted, the module-level
            ``laserwheel_proxy`` is used. That global is looked up at call
            time, so existing ``LaserWheel()`` call sites keep working.

    The constructor queries the wheel once, so a proxy must be reachable.
    """

    def __init__(self, laserwheel_proxy=None):
        self.laserwheel_proxy = laserwheel_proxy
        self.number_filtre = number_filtre
        self.phase_filtre = phase_filtre
        self.delta_filtre = delta_filtre
        # Map slot index -> filter name, plus the inverse mapping name -> slot.
        self.band_names = {slot: l_POS[slot] for slot in range(self.number_filtre)}
        self.band_pos = {name: slot for slot, name in self.band_names.items()}
        # Initialise the cached filter name from the wheel's real position.
        self.get_filter()

    def _proxy(self):
        """Return the explicit proxy, falling back to the module-level one."""
        if self.laserwheel_proxy is not None:
            return self.laserwheel_proxy
        # Bare name: the module-global proxy defined in the script body.
        return laserwheel_proxy

    def angle_to_slot(self, angle):
        """Return the index of the slot nearest to *angle*."""
        return round((angle - self.phase_filtre)/self.delta_filtre)

    def slot_to_angle(self, slot):
        """Return the wheel angle of slot *slot*."""
        return self.phase_filtre + slot * self.delta_filtre

    def get_filter(self):
        """Read the wheel position, then cache and return the name of the filter in place.

        Raises:
            KeyError: if the reported angle does not map to a known slot.
        """
        angle = self._proxy().get_position()
        slot = self.angle_to_slot(angle)
        if slot not in self.band_names:
            raise KeyError(f'laser wheel angle {angle} maps to unknown slot {slot}')
        self._filter = self.band_names[slot]
        return self._filter

    def set_filter(self, filtername):
        """Rotate the wheel to *filtername*.

        Does nothing if *filtername* is already the cached current filter.
        The comparison uses the cached name, not a fresh position read.

        Raises:
            KeyError: if *filtername* is not one of the names in ``l_POS``.
        """
        if self._filter != filtername:
            filterpos = self.slot_to_angle(self.band_pos[filtername])
            self._proxy().set_position(filterpos)
            self._filter = filtername

# Wall-clock start of the whole sequence; the elapsed time is printed at the
# end of the script.
big_start_time = time.time() 


# XML-RPC proxies for the instruments (port 8014: laser, 8015: laser filter
# wheel, 8012: Keithley electrometer).
laser = ServerProxy('http://134.158.154.15:8014')
#spectro = ServerProxy('http://127.0.0.1:8013')
# Spectrograph acquisition is disabled for this run.
spectro=None
laserwheel_proxy = ServerProxy('http://134.158.154.15:8015')
# LaserWheel uses the module-level laserwheel_proxy and queries the wheel
# position in its constructor, so it must be created after the proxy above.
laserwheel = LaserWheel() #laserwheel_proxy=laserwheel_proxy)
keithley = ServerProxy('http://134.158.154.15:8012')
# Put the Keithley in charge mode and run one zero-check cycle.
before_exposure(keithley)
draw = False #define if you want to plot for every wavelength

laser.set_power_on()
laserwheel.set_filter('EMPTY')
# Fixed wait after power-on and the filter move; presumably settling time (confirm).
time.sleep(5) 
# NOTE(review): delay units are not visible here (seconds?). The main loop
# resets both delays to 1 before every wavelength and lowers them to 0.5
# when a burst is split.
laser.set_delay_before(1)
laser.set_delay_after(1)

laser.set_mode('Burst')
##- Alternative
# laser.set_mode('Continuous')

#- SZF
#laser.set_energy_level('MAX')
##- Alternative:
# The main loop switches to 'MAX' for qsw == 'Max' and back to 'Adjust' afterwards.
laser.set_energy_level('Adjust')
#laser.set_qsw(int(value))
# Initial pulse count; the main loop sets npulses again for every wavelength.
npulses = 1000
laser.set_npulses(npulses)

##- Alternative check
#print(laser.get_mode())
#print(laser.get_energy_level())
#print(laser.get_interlock_state())

##- Set the keithley range:
keithley.set_charge_range_upper(2) # 2 or 20 micro C
#keithley.set_charge_range_lower(20) # 20 or 200 nano C

# Keysight electrometer (its readings are saved as the SOLARCELL extension),
# opened over USB via PyVISA.
rm = visa.ResourceManager()
A_keysight_id = 'USB0::2391::37912::MY54321261::0::INSTR'
A_key = rm.open_resource(A_keysight_id)
print ('Hello?')
# Interactive run metadata, used in file/directory names, the README and FITS headers.
run_name = input('Enter the string that should identify this run (NO SPACES PLEASE; example: QSW_sequence_for_lin_test):')
extra_end_str = input('Enter the string that will be added as a suffix to every file name (no spaces please; typically this is nothing, so just hit [Enter]):') 
aperture_size = input('What is the aperture between integrating sphere and CBP? (examples: empty, 5mm, 2mm, 75um, ...)')
print ('aperture_size = ' + str(aperture_size)) 

A_key.timeout = 1000000 #Keysight timeout in milliseconds - needs to be high if n_pl_cycles is high

#Set charge mode and range

# charge_mode / current_mode only select which MODE and KEY_GAIN values are
# written to the FITS header in the main loop.
charge_mode = True
A_key.write(':SENSe:FUNCtion:ON "CHAR"')
charge_range_str = '2e-6' #charge range in coulombs - max value is 2e-6 
A_key.write(":SENS:CHAR:RANG:AUTO OFF")
A_key.write(":SENS:CHAR:RANG " + charge_range_str)

#Set current mode and range

# NOTE(review): if current_mode is set to True, curr_range_str below must be
# uncommented as well, or writing the FITS header will raise NameError.
current_mode = False
#A_key.write(':SENSe:FUNCtion:ON "CURR"')
#curr_range_str = '1E-7'
#A_key.write(":SENS:CURR:RANG:AUTO OFF")
#A_key.write(":SENS:CURR:RANG " + curr_range_str)

#Set measurement speed
n_pl_cycles_str = '100'
# 1/50: presumably the mains period in seconds (50 Hz), despite the name.
# The main loop uses it with n_pl_cycles_str to compute the Keysight sample spacing.
pl_freq = 1/50.0 
#A_key.write(":SENS:CHAR:NPLC:AUTO OFF")
#A_key.write(":SENS:CHAR:NPLC " + n_pl_cycles_str)
# NOTE(review): NPLC is written to the CURR function although the enabled
# function is CHAR (the CHAR variants above are commented out). Confirm which
# setting the instrument applies to charge readings.
A_key.write(":SENS:CURR:NPLC:AUTO OFF")
A_key.write(":SENS:CURR:NPLC " + n_pl_cycles_str)

#Adjust trigger timing parameters
# Timer-triggered acquisition: trigger_time_interval_str between triggers
# (presumably seconds) and n_samples_str samples per :INIT:ACQ.
trigger_delay_time_str = '0'
A_key.write(':TRIG:ACQ:DEL ' + trigger_delay_time_str) 
A_key.write(':TRIG:SOUR TIM')
trigger_time_interval_str = '2e-3'
A_key.write(':TRIG:TIM ' + trigger_time_interval_str)
n_samples_str = '15000'
A_key.write(':TRIG:COUN ' + n_samples_str)

############# PARAMETERS TO MODIFY #########################

#Set the list of wavelengths
l_wavelength = np.arange(350, 1050, 5)
#l_wavelength = [500]

# Repeat the whole wavelength scan n_w_loops times, back to back.
n_w_loops = 10
l_wavelength = can.flattenListOfLists([l_wavelength for i in range(n_w_loops)] )

#Set the number of bursts
# NOTE(review): this value is only recorded in the README. The main loop
# resets n_bursts to 5 (and may scale it up) for every wavelength.
n_bursts = 300 # normally 5 

#Set the list of QSWs
l_qsw = np.arange(285, 301, 1)
l_qsw = l_qsw.tolist()
l_qsw.reverse()
l_qsw = ['Max'] +  [str(elem) for elem in l_qsw]
# Override: only the laser's MAX energy level is used (delete to scan the QSWs above).
l_qsw = ['Max']

#Set the list of filters
l_filter = ['EMPTY', 'red532', 'blue680', '1064']
# Override: only the EMPTY slot is used (delete to cycle through all filters).
l_filter = ['EMPTY']

today = datetime.date.today() 
date_str = today.strftime('%Y%m%d')  # e.g. '20240105'
tally_file = 'tally.txt'
readme_file_root = 'README' 

#Set the names of directory and files

extra_dir_str = input('Enter subdirectory for data run (under ~/' + 'ut' + date_str + '/)' + '  Empty return will create the following subdirectory : '+  run_name + ': ')
if len(extra_dir_str) > 0:
    extra_dir_str = extra_dir_str + '/'
elif len(run_name) > 0:
    extra_dir_str = run_name + '/' 
else: 
    extra_dir_str = ''
dir_root = './'

# Create ./ut<date>/<extra_dir_str> and any missing parents (also handles nested subdirectories).
os.makedirs(os.path.join(dir_root, 'ut' + date_str, extra_dir_str), exist_ok=True)
# save_dir ends with '/' only when extra_dir_str is non-empty, so paths must
# be built from it with os.path.join rather than plain concatenation.
save_dir = 'ut' + date_str + ('/' + extra_dir_str if len(extra_dir_str) > 0 else '')
tally_path = os.path.join(dir_root, save_dir, tally_file)
if not os.path.exists(tally_path):
    # The tally holds the number of the last SC_*.fits file written in this directory.
    with open(tally_path, 'w') as tally:
        tally.write('0')

#############################################################

start = time.time()
date = datetime.datetime.now().isoformat()
# Running exposure counter, written to each FITS header as EXPNUM.
expnum = 1
# Number of the first file of this run = last number in the tally + 1.
with open(os.path.join(dir_root, save_dir, tally_file), 'r') as tally:
    current_tally_num = int(tally.read()) + 1
# NOTE(review): 'lasernburst' records the value of n_bursts set in the parameters
# section; the main loop overrides n_bursts per wavelength.
readme_lines = ['Read me file for run titled: ' + run_name + '\n',
                'Sequence span: ' + str(current_tally_num) + ' TO ' , # '\n' is added at the end, once the last file number is known
                'lasernburst:' + str( n_bursts) + '\n',
                'DATE-OBS: ' + str( date)  + '\n',
                'darktime: ' + str( KEITHLEY_DISCHARGE_TIME ) + '\n',
                'KEY_GAIN: ' + str( charge_range_str ) + '\n',
                'APERTURE: ' + str( aperture_size ) + '\n',
                'ADDITIONAL NOTES: ' + '\n', 
                ]
print ('The current readme file for this run includes the following lines: ' )
for line in readme_lines:
    print (line)
# Let the operator append free-form notes, one line at a time; an empty entry stops.
while True:
    additional_line = input('Please enter additional lines to add to readme, one at a time.  An empty string (just [Enter]) will terminate the sequence: ')
    if not additional_line:
        break
    readme_lines.append(additional_line + '\n')

for filt in l_filter :

    laserwheel.set_filter(filt)
    time.sleep(1)
    print('Filter : ', laserwheel.get_filter())

    for qsw in l_qsw :

        # 'Max' selects the laser's MAX energy level; any other entry is a
        # Q-switch value for set_qsw. The 10 s waits presumably let the laser settle.
        if qsw == 'Max':
            laser.set_energy_level('MAX')
            time.sleep(10)
            print(f'QSW : {laser.get_energy_level()}')
        else:
            laser.set_qsw(int(qsw))
            time.sleep(10)
            print(f'QSW : {laser.get_qsw()}')

        for wl in l_wavelength:
            iteration_start = time.time()
            laser.set_delay_before(1)
            laser.set_delay_after(1)
            n_bursts = 5
            npulses, spectro_exptime = calculate_npulses(wl)
            npulses_lim = 200

            # Never request more than npulses_lim pulses per burst. Instead,
            # scale up the number of bursts (keeping npulses * n_bursts roughly
            # constant) and shorten the laser delays.
            if npulses > npulses_lim :
                n_bursts = int((npulses/npulses_lim)*n_bursts)
                npulses = npulses_lim
                laser.set_delay_before(0.5)
                laser.set_delay_after(0.5)

            print('npulses = ' + str(npulses))
            print('spectro_exptime = '+str(spectro_exptime))
            print('wavelength = ' + str(wl))
            laser.set_npulses(npulses)
            burst_starts = [-1] * n_bursts
            burst_ends = [-1] * n_bursts
            laser.set_wavelength(f'{int(wl)}')
            # Reset both electrometers: Keithley zero-check cycle, Keysight discharge command.
            before_exposure(keithley)
            A_key.write("SENS:CHAR:DISCharge")
            # Keithley readings are polled on a background thread until keithley_join.
            non_blocking_integration_loop(keithley, spectro=None, nreads=1, spectro_exptime=spectro_exptime)

            print('starting Keithley acquisition ')
            #Tell Keysight to start taking data
            A_key.write(':INP ON')
            A_key.write(':INIT:ACQ')
            print('starting Keysight acquisition ')
            for i in range(n_bursts) :
                start_burst = time.time()
                burst_starts[i] = start_burst
                laser.trigger_burst()
                end_burst = time.time()
                burst_ends[i] = end_burst
                print ('Burst ' + str(i+1) + ' of ' + str(n_bursts) + ' took ' + str(end_burst - start_burst) + 's.')

            print ('About to ask for data...')
            results_str = A_key.query(':FETC:ARR:CHAR?')
            print ('Finished asking for data...')
            kei_data = keithley_join(keithley)[0]
            # Reorder the Keithley records into [times, charges].
            kei_data= [[elem[1] for elem in kei_data], [elem[0] for elem in kei_data]]

            print(f'Charge in keithley :{kei_data[1][-1]}')

            # The Keysight reply is one comma-separated line. strip() removes
            # any trailing terminator. Blindly slicing off the last character
            # would corrupt the final value whenever no terminator is present.
            results_data = [float(elem) for elem in results_str.strip().split(',')]

            # NOTE(review): assumed spacing between Keysight samples
            # (trigger interval + NPLC * pl_freq); confirm against the instrument.
            data_point_time_sep = float(trigger_time_interval_str) + float(n_pl_cycles_str) * pl_freq
            delta_ts = [data_point_time_sep * i for i in range(int(n_samples_str))]

            SC_data = [can.round_to_n(elem, 5) for elem in results_data]
            SC_time = [can.round_to_n(elem, 5) for elem in delta_ts]
            PD_data = kei_data[1]
            PD_time = kei_data[0]

            ####################### SAVE TO FITS FILE #################################

            ### Save the data from the solar cell ###

            col_SC_data = fits.Column(name='charge', format='E', array=SC_data)
            col_SC_time = fits.Column(name='time', format='E', array=SC_time)
            cols_SC = fits.ColDefs([col_SC_time, col_SC_data])
            hdu1 = fits.BinTableHDU.from_columns(cols_SC)
            hdu1.header['EXTNAME'] = 'SOLARCELL'

            ### Save the data from the photodiode ###

            col_PD_data = fits.Column(name='charge', format='E', array=PD_data)
            col_PD_time = fits.Column(name='time', format='E', array=PD_time)
            cols_PD = fits.ColDefs([col_PD_time, col_PD_data])
            hdu2 = fits.BinTableHDU.from_columns(cols_PD)
            hdu2.header['EXTNAME'] = 'KEITHLEY'

            ### Save all the metadata ###

            hdr = fits.Header()
            hdr['lasernpulses'] = npulses
            hdr['laserwavelength'] = wl
            hdr['laserwheelfilter'] = filt
            hdr['lasernburst'] = n_bursts
            hdr['LASERQSW'] = qsw
            hdr['DATE-OBS'] = date
            hdr['pinhole'] = aperture_size
            hdr['EXPNUM'] = expnum
            hdr['darktime'] = KEITHLEY_DISCHARGE_TIME

            if charge_mode :
                hdr['MODE'] = 'charge'
                hdr['KEY_GAIN'] = charge_range_str

            # NOTE(review): curr_range_str only exists if the current-mode setup
            # lines are uncommented; otherwise this branch raises NameError.
            if current_mode :
                hdr['MODE'] = 'current'
                hdr['KEY_GAIN'] = curr_range_str

            hdr['APERTURE'] = aperture_size
            empty_hdu = fits.PrimaryHDU(header=hdr)

            ### Write the .fits file ###

            hdul = fits.HDUList([empty_hdu, hdu1, hdu2])
            # Claim the next file number from the tally and record it before
            # writing the FITS file.
            tally_path = os.path.join(dir_root, save_dir, tally_file)
            with open(tally_path, 'r') as tally:
                current_tally_num = int(tally.read()) + 1
            print ('current_tally_num = ' + str(current_tally_num))
            fits_savepath = os.path.join(f'{dir_root}', f'{save_dir}',  f'SC_{current_tally_num:07d}{extra_end_str}.fits')
            with open(tally_path, 'w') as tally:
                tally.write(str(current_tally_num ))
            hdul.writeto(fits_savepath, overwrite=True)

            iteration_end = time.time()
            expnum += 1
            print ('One round through loop took ' + str(iteration_end - iteration_start) + 's')

            # Per-wavelength quick-look plots, enabled by the draw flag.
            if draw:
                fig, ax = plt.subplots(2, 1)
                ax[1].plot(delta_ts, results_data)
                ax[1].scatter(delta_ts, results_data, marker = '+')
                ax[1].set_xlabel(r'$\Delta t$ (ms)')
                ax[1].set_ylabel(r'SC Charge (C)')
                ax[0].plot(kei_data[0], kei_data[1], '+')
                ax[0].set_xlabel(r'PD time (s)')
                ax[0].set_ylabel(r'PD Charge (C)')
                plt.draw()
                plt.pause(0.05)
                plt.close()

        # Restore the default energy mode after a 'Max' pass.
        if qsw == 'Max':
            laser.set_energy_level('Adjust')
# Close the "Sequence span" README line with the number of the last file written.
readme_lines[1] = readme_lines[1] + str(current_tally_num) + '\n'
# save_dir may lack a trailing '/', so join paths rather than concatenating them.
save_path = os.path.join(dir_root, save_dir)
print ('[file for file in dir_root + save_dir] = ' + str([file for file in os.listdir(save_path)])) 
readme_prefix = readme_file_root + '_' + run_name
print ("readme_file_root + '_' + run_name = " + str(readme_prefix)) 
# Number README files per run name so earlier ones are never overwritten.
n_existing_readmes = len([file for file in os.listdir(save_path) if file.startswith(readme_prefix)])

print ('n_existing_readmes = ' + str(n_existing_readmes)) 
readme_path = os.path.join(save_path, readme_prefix + '_' + str(n_existing_readmes) + '.txt')
with open(readme_path, 'w') as f_readme:
    f_readme.writelines(readme_lines)

A_key.write(':INP OFF')
big_end_time  = time.time()
print ('Full sequence took ' + str(big_end_time - big_start_time) + 's')  

# Turn the laser off before the blocking plt.show(), so it does not stay on
# while the final plot window is open.
laser.set_power_off()

# Quick-look plot of the last exposure.
fig, ax = plt.subplots(2, 1)
ax[1].plot(delta_ts, results_data)
ax[1].scatter(delta_ts, results_data, marker = '+')
ax[1].set_xlabel(r'$\Delta t$ (ms)')
ax[1].set_ylabel(r'SC Charge (C)')
ax[0].plot(kei_data[0], kei_data[1], '+') 
ax[0].set_xlabel(r'PD time (s)')
ax[0].set_ylabel(r'PD Charge (C)')
plt.show()

