#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Created on 15 October 2018

@author: Anders Muskens

Testing script for determining the curve fit for analyze-shifts.py TSV files.

'''

from __future__ import division, print_function

import sys
import argparse
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.optimize import curve_fit
from mpl_toolkits.mplot3d import Axes3D

import csv
import collections


def abline(slope, intercept):
    """Draw a dashed straight line spanning the x-limits of the current axes.

    slope: gradient of the line
    intercept: value of the line at x = 0
    """
    # Only the two end points are needed to draw a straight line.
    x_limits = np.array(plt.gca().get_xlim())
    plt.plot(x_limits, slope * x_limits + intercept, '--')

def exp_func(x, a, b, c):
    """Exponential model: ``a * exp(-b * x) + c``."""
    decay = np.exp(-b * x)
    return a * decay + c


def quad_func(x, a, b, c):
    """Quadratic model in vertex form: ``a * (x - b)**2 + c``."""
    offset = x - b
    return a * offset ** 2 + c


def cubic_func(x, a, b, c):
    """Shifted cubic model: ``a * (x - b)**3 + c``."""
    offset = x - b
    return a * offset ** 3 + c


def lin_func(x, a, b):
    """Linear model with slope ``a`` and intercept ``b``: ``a * x + b``."""
    scaled = a * x
    return scaled + b


def log_func(x, a, b, c):
    """Shifted natural-logarithm model: ``a * ln(x - b) + c``."""
    shifted = x - b
    return a * np.log(shifted) + c


def quart_func(x, a, b, c, d):
    """Quartic model: ``a * (b * (x - c))**4 + d``."""
    scaled_offset = b * (x - c)
    return a * scaled_offset ** 4 + d


def arctan_func(x, a, b, c, d):
    """Arctangent plus linear trend: ``a * arctan(b * x) + c * x + d``."""
    saturating_part = a * np.arctan(b * x)
    linear_part = c * x
    return saturating_part + linear_part + d


def recip_func(x, a, b, c, d):
    """Reciprocal model: ``a / (b * (x - c)) + d``, evaluated as ``a * (1 / ...)``."""
    denominator = b * (x - c)
    return a * (1 / denominator) + d


def _read_shift_tables(filenames):
    """Load analyze-shifts TSV files into a ``res -> zoom -> td -> shift`` dict.

    filenames: iterable of paths to tab-separated files with the columns
        'zoom', 'res X', 'dwell time (s)' and 'shift X (base px)'
    returns: nested dict keyed by X resolution (int), zoom (float) and
        dwell time (float), holding the X shift (float)
    raises: KeyError if a column is missing, ValueError if a cell is not numeric
    """
    table = {}
    for filename in filenames:
        with open(filename) as tsvfile:
            for row in csv.DictReader(tsvfile, delimiter='\t'):
                zoom = float(row['zoom'])
                res = int(row['res X'])
                td = float(row['dwell time (s)'])
                shift = float(row['shift X (base px)'])
                # A later row with the same (res, zoom, td) replaces the earlier one.
                table.setdefault(res, {}).setdefault(zoom, {})[td] = shift
    return table


def main(args):
    """
    Fit arctan_func to shift-vs-dwell-time for every (resolution, zoom) pair
    found in the given TSV files, and show each fit next to its data.

    args: full argument vector with the program name first (as in sys.argv)
    """
    # arguments handling
    parser = argparse.ArgumentParser(description="Test curve-fitting strategies on analyze-shifts data. ")
    parser.add_argument(dest="filenames", nargs="+",
                        help="filenames of the TSV tables generated by analyze_shifts.py")
    options = parser.parse_args(args[1:])

    data = _read_shift_tables(options.filenames)

    # res -> zoom -> fitted arctan_func parameters; calib[res] has the
    # zoom -> parameters layout that s_of_z_td() takes.
    calib = {}

    for res, by_zoom in data.items():
        for zoom, by_td in by_zoom.items():
            # Fit over dwell times in ascending order.
            td = np.array(sorted(by_td))
            s = np.array([by_td[t] for t in td])

            try:
                popt, _ = curve_fit(arctan_func, td, s)
            except (RuntimeError, TypeError) as err:
                # The fit did not converge or there are too few points for the
                # model: skip this series, but say so instead of hiding it.
                print("Skipping res=%s zoom=%s: curve fit failed (%s)" % (res, zoom, err),
                      file=sys.stderr)
                continue

            # Fixed dwell-time range (in seconds) over which the fit is drawn.
            xdata = np.linspace(0.00000096, 0.00004, 100)
            print(zoom)
            plt.plot(xdata, arctan_func(xdata, *popt))
            plt.plot(td, s, 'r-')
            plt.show()

            calib.setdefault(res, {})[zoom] = popt


def _fit_across_zooms(model, zoom, td, calib, neighbours):
    """Fit ``model`` through the per-zoom predictions at ``td`` and evaluate it at ``zoom``.

    model: callable ``model(z, *params)`` to fit as a function of zoom
    neighbours: calibrated zoom levels (keys of calib) to fit through
    """
    values_at_td = [arctan_func(td, *calib[z]) for z in neighbours]
    params, _ = curve_fit(model, neighbours, values_at_td)
    return model(zoom, *params)


def s_of_z_td(zoom, td, calib):
    """
    Estimate the shift at a given zoom level and dwell time.

    When ``zoom`` has its own entry in ``calib``, its arctan_func fit is
    evaluated directly at ``td``. Otherwise the fits of the calibrated zoom
    levels in the same zoom range are evaluated at ``td``, a curve is fitted
    through those values as a function of zoom, and that curve is evaluated
    at ``zoom``.

    zoom: zoom level to evaluate
    td: dwell time to evaluate
    calib: dict of float -> [a, b, c, d] where the float is a Zoom level and
        the list holds the arctan_func parameters fitted for that zoom
    returns: the estimated shift; None only if no zoom range matches
        (e.g. zoom is NaN)
    raises: any exception curve_fit raises when the interpolation fit fails,
        e.g. because too few calibrated zoom levels fall in the range
    """
    if zoom in calib:
        return arctan_func(td, *calib[zoom])

    # No zoom for this position. Interpolate.
    zooms = sorted(calib)

    if zoom <= 2:
        return _fit_across_zooms(lin_func, zoom, td, calib,
                                 [z for z in zooms if z <= 2])

    if td > 0.00000192:
        # The upper bound is inclusive so that zoom == 20 is covered: the fit
        # set below already includes z == 20, and the next range starts
        # strictly above 20.
        if 2 < zoom <= 20:
            return _fit_across_zooms(arctan_func, zoom, td, calib,
                                     [z for z in zooms if 2 <= z <= 20])
        if zoom > 20:
            return _fit_across_zooms(quad_func, zoom, td, calib,
                                     [z for z in zooms if z > 20])
    elif zoom > 2:
        # At or below the dwell-time threshold a single arctan fit is used
        # for every zoom above 2.
        return _fit_across_zooms(arctan_func, zoom, td, calib,
                                 [z for z in zooms if z >= 2])

    # Only reached when every comparison above is False (e.g. NaN zoom).
    return None


# Script entry point: pass the raw argument vector (program name included) to main().
if __name__ == "__main__":
    main(sys.argv)
