#!/usr/bin/env python3


import os
import sys
import pandas
import numpy
import matplotlib
import matplotlib.pyplot as plt

figsize = (5,3)


def make_plot_ellipticity(data, bn=None, formats=None, pixelscale=1.0):
    """Plot ellipticity against effective radius and save the figure.

    Parameters
    ----------
    data : DataFrame-like
        Must provide the columns ``'eff_radius'`` and ``'ellipticity'``
        (``'eff_radius'`` is added by ``make_plot``).
    bn : str or None
        Basename prefix for the output files. ``None`` writes
        ``ellipticity.<fmt>``; otherwise ``<bn>__ellipticity.<fmt>``.
    formats : iterable of str or None
        File extensions to save, one file per entry. ``None`` falls back to
        ``['png', 'pdf']``, the same default ``make_plot`` uses.
    pixelscale : float or None
        Only selects the x-axis label (``None`` -> pixels, otherwise arcsec);
        the radii are plotted exactly as stored in ``data['eff_radius']``.
    """
    if formats is None:
        formats = ['png', 'pdf']

    fig = plt.figure(figsize=figsize, dpi=150)
    ax = fig.add_subplot(111)

    if pixelscale is None:
        ax.set_xlabel("effective radius [pixels]")
    else:
        ax.set_xlabel("effective radius [arcsec]")

    ax.scatter(data['eff_radius'], data['ellipticity'])
    ax.set_xscale('log')
    ax.set_ylabel("ellipticity")

    for f in formats:
        if bn is None:
            fn = "ellipticity.%s" % (f)
        else:
            fn = "%s__ellipticity.%s" % (bn, f)
        print("saving plot to %s" % (fn))
        fig.savefig(fn, bbox_inches='tight')

    # pyplot keeps every figure alive until it is explicitly closed; release
    # it so repeated calls (e.g. many objects in one process) do not leak.
    plt.close(fig)





def make_plot_positionangle(data, bn=None, formats=None, pixelscale=1.0):
    """Plot position angle (with error bars) against effective radius and save.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``'eff_radius'``, ``'pa'`` and ``'pa_err'``
        (``'eff_radius'`` is added by ``make_plot``).
    bn : str or None
        Basename prefix for the output files. ``None`` writes
        ``position_angle.<fmt>``; otherwise ``<bn>__position_angle.<fmt>``.
    formats : iterable of str or None
        File extensions to save. ``None`` falls back to ``['png', 'pdf']``.
    pixelscale : float or None
        Only selects the x-axis label (``None`` -> pixels, otherwise arcsec).
    """
    if formats is None:
        formats = ['png', 'pdf']

    fig = plt.figure(figsize=figsize, dpi=150)
    ax = fig.add_subplot(111)

    if pixelscale is None:
        ax.set_xlabel("effective radius [pixels]")
    else:
        ax.set_xlabel("effective radius [arcsec]")

    # fmt="o" already selects circle markers with no connecting line.
    ax.errorbar(data['eff_radius'], data['pa'], yerr=data['pa_err'], fmt="o")
    ax.scatter(data['eff_radius'], data['pa'], marker="o")

    ax.set_xscale('log')
    ax.set_ylabel("position angle [degrees]")

    for f in formats:
        if bn is None:
            fn = "position_angle.%s" % (f)
        else:
            fn = "%s__position_angle.%s" % (bn, f)
        print("saving plot to %s" % (fn))
        fig.savefig(fn, bbox_inches='tight')

    # Release the figure so pyplot does not keep it alive after saving.
    plt.close(fig)

def make_plot_intensity(data, bn=None, formats=None, pixelscale=1.0):
    """Plot background-subtracted intensity against effective radius and save.

    Both axes are logarithmic.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``'sma'``, ``'ellipticity'``, ``'intens_bgsub'`` and
        ``'intens_err'`` (``'intens_bgsub'`` is added by ``make_plot``).
    bn : str or None
        Basename prefix for the output files. ``None`` writes
        ``intensity.<fmt>``; otherwise ``<bn>__intensity.<fmt>``.
    formats : iterable of str or None
        File extensions to save. ``None`` falls back to ``['png', 'pdf']``.
    pixelscale : float or None
        Multiplies the radius computed from ``sma`` (``None`` is treated as
        1.0, i.e. radii stay in pixels) and selects the x-axis label.
    """
    if formats is None:
        formats = ['png', 'pdf']

    ps = 1.0 if pixelscale is None else pixelscale

    fig = plt.figure(figsize=figsize, dpi=150)
    ax = fig.add_subplot(111)

    if pixelscale is None:
        ax.set_xlabel("effective radius [pixels]")
    else:
        ax.set_xlabel("effective radius [arcsec]")

    # sma * sqrt(1 - e) equals sqrt(a*b), assuming ellipticity = 1 - b/a.
    eff_radius = data['sma'] * numpy.sqrt(1. - data['ellipticity']) * ps
    # NOTE(review): error bars use the raw 'intens_err'; make_plot also builds
    # 'intens_error_full' (which adds the background uncertainty) and that may
    # be the more appropriate error for background-subtracted values — confirm.
    ax.errorbar(eff_radius, data['intens_bgsub'], yerr=data['intens_err'], fmt="o")
    ax.scatter(eff_radius, data['intens_bgsub'], marker="o")

    ax.set_xscale('log')
    ax.set_yscale('log')

    # Previously mislabelled "position angle [degrees]" (copy-paste slip).
    ax.set_ylabel("background-subtracted intensity")

    for f in formats:
        if bn is None:
            fn = "intensity.%s" % (f)
        else:
            fn = "%s__intensity.%s" % (bn, f)
        print("saving plot to %s" % (fn))
        fig.savefig(fn, bbox_inches='tight')

    # Release the figure so pyplot does not keep it alive after saving.
    plt.close(fig)



def make_plot_surfacebrightness(data, bn=None, formats=None, pixelscale=1.0,
                                band="WISE-W1"):
    """Plot surface brightness (mag/arcsec^2) against effective radius and save.

    The y-axis is inverted so brighter (numerically smaller) magnitudes sit
    at the top.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``'eff_radius'``, ``'mag_arcsec'`` and
        ``'mag_arcsec_err'`` (all added by ``make_plot``).
    bn : str or None
        Basename prefix for the output files. ``None`` writes
        ``surface_brightness.<fmt>``; otherwise
        ``<bn>__surface_brightness.<fmt>``.
    formats : iterable of str or None
        File extensions to save. ``None`` falls back to ``['png', 'pdf']``.
    pixelscale : float or None
        Only selects the x-axis label (``None`` -> pixels, otherwise arcsec).
    band : str
        Band name shown in the y-axis label. The default reproduces the label
        this function always used.
    """
    if formats is None:
        formats = ['png', 'pdf']

    fig = plt.figure(figsize=figsize, dpi=150)
    ax = fig.add_subplot(111)

    if pixelscale is None:
        ax.set_xlabel("effective radius [pixels]")
    else:
        ax.set_xlabel("effective radius [arcsec]")

    ax.errorbar(data['eff_radius'], data['mag_arcsec'], yerr=data['mag_arcsec_err'], fmt="o")
    ax.scatter(data['eff_radius'], data['mag_arcsec'], marker="o")

    ax.set_xscale('log')
    # Magnitudes decrease with brightness: flip so bright is up.
    ax.invert_yaxis()

    ax.set_ylabel(r"%s surface brightness [mag arcsec$^{-2}$]" % (band))

    for f in formats:
        if bn is None:
            fn = "surface_brightness.%s" % (f)
        else:
            fn = "%s__surface_brightness.%s" % (bn, f)
        print("saving plot to %s" % (fn))
        fig.savefig(fn, bbox_inches='tight')

    # Release the figure so pyplot does not keep it alive after saving.
    plt.close(fig)


def make_plot_ellipticity_positionangle(data, bn=None, formats=None, pixelscale=1.0,
                                        xlim=(6, 205), pa_ylim=(-5, 40)):
    """Plot ellipticity and position angle on twin y-axes and save the figure.

    Ellipticity (blue, left axis) and position angle (green, right axis) share
    a logarithmic effective-radius x-axis. A red arrow from y=0 marks the
    radius of peak ellipticity.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``'eff_radius'``, ``'ellipticity'`` and ``'pa'``
        (``'eff_radius'`` is added by ``make_plot``).
    bn : str or None
        Basename prefix for the output files. ``None`` writes
        ``ellipticity_posangle.<fmt>``; otherwise
        ``<bn>__ellipticity_posangle.<fmt>``.
    formats : iterable of str or None
        File extensions to save. ``None`` falls back to ``['png', 'pdf']``.
    pixelscale : float or None
        Only selects the x-axis label (``None`` -> pixels, otherwise arcsec).
    xlim : (float, float) or None
        x-axis limits. The default is the previously hard-coded range;
        ``None`` lets matplotlib autoscale.
    pa_ylim : (float, float) or None
        Position-angle axis limits, same convention as ``xlim``.
    """
    if formats is None:
        formats = ['png', 'pdf']

    # Plain arrays make the peak lookup below positional. Indexing a pandas
    # Series with an argmax position is label-based and picks the wrong row
    # (or raises KeyError) when the index is not 0..n-1.
    eff_radius = numpy.asarray(data['eff_radius'], dtype=float)
    ellipticity = numpy.asarray(data['ellipticity'], dtype=float)
    pa = numpy.asarray(data['pa'], dtype=float)

    fig = plt.figure(figsize=figsize, dpi=150)
    ax = fig.add_subplot(111)

    if pixelscale is None:
        ax.set_xlabel("effective radius [pixels]")
    else:
        ax.set_xlabel("effective radius [arcsec]")

    ax.set_xscale('log')
    if xlim is not None:
        ax.set_xlim(xlim)

    # Left axis: ellipticity.
    ax.scatter(eff_radius, ellipticity, color='blue', s=4)
    ax.plot(eff_radius, ellipticity, color='blue', alpha=0.1)
    ax.set_ylabel("ellipticity", color='blue')
    # nan-aware max, matching pandas' NaN-skipping Series.max().
    ax.set_ylim((0, 1.1 * numpy.nanmax(ellipticity)))
    ax.tick_params(axis='y', colors='blue')

    # Right axis: position angle, sharing the x-axis.
    ax2 = ax.twinx()
    ax2.scatter(eff_radius, pa, marker="o", color='green', s=4)
    ax2.plot(eff_radius, pa, color='green', alpha=0.1)
    if pa_ylim is not None:
        ax2.set_ylim(pa_ylim)
    ax2.set_ylabel("position angle [degrees]", color='green')
    ax2.tick_params(axis='y', colors='green')

    # Arrow from y=0 up to the peak-ellipticity point.
    i_peak = numpy.nanargmax(ellipticity)
    peak_e = ellipticity[i_peak]
    peak_e_radius = eff_radius[i_peak]
    ax.annotate("", xy=(peak_e_radius, peak_e), xytext=(peak_e_radius, 0),
                arrowprops=dict(arrowstyle="->, head_width=0.5, head_length=2",
                                color='red', linewidth=1.5))

    for f in formats:
        if bn is None:
            fn = "ellipticity_posangle.%s" % (f)
        else:
            fn = "%s__ellipticity_posangle.%s" % (bn, f)
        print("saving plot to %s" % (fn))
        fig.savefig(fn, bbox_inches='tight')

    # Release the figure so pyplot does not keep it alive after saving.
    plt.close(fig)

def make_plot(data, bn=None, formats=None, pixelscale=1.0, zeropoint=0):
    """Derive profile quantities from ``data`` and write all profile plots.

    ``data`` is modified in place. These columns are added:

    - ``'intens_bgsub'``: intensity minus an estimated background.
    - ``'intens_error_full'``: ``intens_err`` combined in quadrature with the
      background uncertainty.
    - ``'mag_arcsec'`` / ``'mag_arcsec_err'``: surface brightness and its
      error.
    - ``'eff_radius'``: ``sma * sqrt(1 - ellipticity) * pixelscale``.

    Then five figures are saved via the ``make_plot_*`` functions.

    Parameters
    ----------
    data : DataFrame-like
        Must provide ``'intens'``, ``'intens_err'``, ``'sma'``,
        ``'ellipticity'``, ``'pa'`` and ``'pa_err'``.
    bn : str or None
        Basename prefix for all output files.
    formats : iterable of str or None
        File extensions to save. ``None`` means ``['png', 'pdf']``.
    pixelscale : float
        Size of a pixel in arcsec. It must be a number here, since it is
        squared for the surface-brightness conversion.
    zeropoint : float
        Photometric zeropoint added to ``-2.5 log10(intensity / pixel area)``.
    """
    if formats is None:
        formats = ['png', 'pdf']

    # Crude background: 97% of the faintest intensity. This keeps every
    # background-subtracted value positive as long as all intensities are
    # positive. The uncertainty on that estimate is taken as 10% of it.
    background_level = 0.97 * numpy.min(data['intens'])
    background_error = 0.1 * background_level

    data['intens_bgsub'] = data['intens'] - background_level
    data['intens_error_full'] = numpy.hypot(data['intens_err'], background_error)

    # Convert per-pixel intensity to per-arcsec^2 before taking magnitudes.
    data['mag_arcsec'] = -2.5 * numpy.log10(data['intens_bgsub'] / pixelscale**2) + zeropoint
    # Standard propagation: sigma_m = 2.5 / ln(10) * sigma_I / I. Here I is the
    # background-subtracted intensity the magnitude was computed from
    # (previously: divided by the raw intensity with an ad-hoc factor of 5).
    data['mag_arcsec_err'] = (2.5 / numpy.log(10.)) * data['intens_error_full'] / data['intens_bgsub']

    # a * sqrt(1 - e) equals sqrt(a*b), assuming ellipticity = 1 - b/a.
    data['eff_radius'] = data['sma'] * numpy.sqrt(1. - data['ellipticity']) * pixelscale

    make_plot_ellipticity(data, bn, formats, pixelscale)
    make_plot_positionangle(data, bn, formats, pixelscale)
    make_plot_intensity(data, bn, formats, pixelscale)
    make_plot_surfacebrightness(data, bn, formats, pixelscale)
    make_plot_ellipticity_positionangle(data, bn, formats, pixelscale)


if __name__ == "__main__":

    # Usage: <script> <table.csv> [output_basename]
    if len(sys.argv) < 2:
        sys.exit("usage: %s <table.csv> [output_basename]" % sys.argv[0])

    fn = sys.argv[1]
    # Optional basename prefix for the output files; None -> unprefixed names.
    bn = sys.argv[2] if len(sys.argv) > 2 else None

    df = pandas.read_csv(fn)
    # NOTE(review): pixel scale 1.375 arcsec/pixel and zeropoint 20.5 are
    # hard-coded, presumably for WISE W1 imaging (the surface-brightness label
    # says WISE-W1) — adjust for other data.
    make_plot(df, bn, pixelscale=1.375, zeropoint=20.5)

