#!/usr/bin/env python3
"""Fit elliptical isophotes to one catalogued source and model its bar.

Usage::

    <script> IMAGE.fits CATALOG_VOTABLE SOURCE_ID SEGMENTATION.fits

IMAGE is the science frame, CATALOG_VOTABLE a VOTable source catalog,
SOURCE_ID the catalog NUMBER of the target and SEGMENTATION the matching
segmentation map used to mask neighbouring objects.
"""

import os
import sys
import astropy.io.fits as pyfits
import numpy
import pandas
import photutils
import photutils.isophote
from astropy.io import votable

# Debug aid: show which photutils installation is being picked up.
print(photutils.__path__)

if __name__ == "__main__":

    fn = sys.argv[1]
    catalog_fn = sys.argv[2]
    source_id = int(sys.argv[3])
    segm_fn = sys.argv[4]

    # read image
    img_hdu = pyfits.open(fn)
    img = img_hdu[0].data
    img_masked = numpy.ma.array(img)

    # read segmentation frame
    segm_hdu = pyfits.open(segm_fn)
    segm = segm_hdu[0].data
    use_for_fit = ((segm == 0) | (segm == source_id)) & numpy.isfinite(img)
    img_masked.mask = ~use_for_fit

    # read catalog
    vot = votable.parse_single_table(catalog_fn)
    catalog = vot.to_table()
    print(catalog)
    print(source_id)
    right_source = (catalog['NUMBER'] == source_id)
    source_data = catalog[right_source][0]
    print("SOURCE-coord: ", source_data['X_IMAGE'], source_data['Y_IMAGE'])

    # sys.exit(-1)

    rerun = True
    surfprofile_csv = "surfprofile.csv"
    if (rerun and os.path.isfile(surfprofile_csv)):
        df = pandas.read_csv(surfprofile_csv)
    else:
        print("Preparing ellipse fit")
        pyfits.PrimaryHDU(data=img_masked.filled(0)).writeto("dummy.fits", overwrite=True)

        # prepare ellipse fitting
        geo = photutils.isophote.EllipseGeometry(
            x0 = source_data['X_IMAGE'],
            y0 = source_data['Y_IMAGE'],
            sma=10, eps=0.01, pa=0,
        )
        # geo.find_center(img_masked)
        print(geo)
        ellipse = photutils.isophote.Ellipse(img_masked, geometry=geo)
        print("ellipse:", ellipse)

        # df = pandas.read_csv("surfprofile.csv")
        print("Fitting image")
        isophot_list = ellipse.fit_image(
            # integrmode='median',
            sclip=3.0, nclip=3, fflag=0.7,
            maxsma=150,
            minsma=5,
        )
        print("done fitting")
        print(isophot_list.to_table())

        df = isophot_list.to_table().to_pandas()
        df.to_csv(surfprofile_csv)
        print("Saved surface brightness profile as CSV")


    #
    # convert all ellipses into ds9 format
    #
    ds9_reg_fn = "ellipse_fit.reg"
    with open(ds9_reg_fn, "w") as ds9_reg:
        print("""\
# Region file format: DS9 version 4.1
global color=green dashlist=8 3 width=1 font="helvetica 10 normal roman" select=1 highlite=1 dash=0 fixed=0 edit=1 move=1 delete=1 include=1 source=1
image""", file=ds9_reg)
        for index, e in df.iterrows():
            # print(e)
            print("ellipse(%f,%f,%f,%f,%f)" % (e['x0'], e['y0'], e['sma'], (1.-e['ellipticity'])*e['sma'], e['pa']),
                  file=ds9_reg)


    #
    # now try to create a model of the bar based on the ellipse data, and compare it to what it
    # would have been without the bar
    #
    i_peak_ellipticity = numpy.argmax(numpy.array(df['ellipticity']))
    peak_ellipticity = df['ellipticity'][i_peak_ellipticity]
    bar_sma = df['sma'][i_peak_ellipticity]
    print("Found peak ellipticity of e=%4f at sma=%.2f [%d]" % (
        peak_ellipticity, bar_sma, i_peak_ellipticity))

    # create a full index array for image
    idx_y, idx_x = numpy.indices(img.shape)
    # rotate coords to account for position angle

    model_data_full = numpy.zeros_like(img)
    model_data_round = numpy.zeros_like(img)
    prev_semiminor = prev_semimajor = 0
    for index, e in df[:i_peak_ellipticity+1].iterrows():

        d_intensity = e['intens'] - df.iloc[index+1]['intens']
        if (d_intensity < 0): d_intensity = 0

        dx = idx_x - e['x0'] #source_data['X_IMAGE']
        dy = idx_y - e['y0'] #source_data['Y_IMAGE']
        posangle = e['pa']
        da = numpy.cos(numpy.radians(posangle)) * dx + numpy.sin(numpy.radians(posangle)) * dy
        db = -numpy.sin(numpy.radians(posangle)) * dx + numpy.cos(numpy.radians(posangle)) * dy
        semimajor = e['sma']
        semiminor = e['sma'] * (1-e['ellipticity'])
        # if (semiminor < prev_semiminor): semiminor = prev_semiminor
        # if (semimajor < prev_semimajor): semimajor = prev_semimajor
        r_ellipse = (da/semimajor)**2 + (db/semiminor)**2
        in_ellipse = r_ellipse <= 1

        fake_semiminor = semiminor
        fake_semimajor = semiminor / (1-0.2)
        fake_r_ellipse = (da/fake_semimajor)**2 + (db/fake_semiminor)**2
        fake_in_ellipse = fake_r_ellipse <= 1

        model_data_this_ellipse = numpy.zeros_like(img)
        model_data_this_ellipse[in_ellipse] = d_intensity  #  e['intens']
    model_data_full += model_data_this_ellipse
        # affected_pixels = model_data_this_ellipse > model_data_full
        # model_data_full[affected_pixels] = model_data_this_ellipse[affected_pixels]

        fake_this_ellipse = numpy.zeros_like(img)
        fake_this_ellipse[fake_in_ellipse] = d_intensity
        model_data_round += fake_this_ellipse

        # fake_this_ellipse[fake_in_ellipse] = e['intens']
        # fake_affected_pixels = fake_this_ellipse > model_data_round
        # model_data_round[fake_affected_pixels] = fake_this_ellipse[fake_affected_pixels]

        # pyfits.PrimaryHDU(data=r_ellipse).writeto("r_ellipse.fits", overwrite=True)
        # pyfits.PrimaryHDU(data=in_ellipse.astype(numpy.int)).writeto("in_ellipse.fits", overwrite=True)

        prev_semiminor = semiminor
        prev_semimajor = semimajor
        # break

    pyfits.PrimaryHDU(data=model_data_full).writeto("model_data_full.fits", overwrite=True)
    pyfits.PrimaryHDU(data=model_data_round).writeto("model_data_round.fits", overwrite=True)

    pyfits.PrimaryHDU(data=(img-model_data_full)).writeto("model_data_full_residuals.fits", overwrite=True)
    pyfits.PrimaryHDU(data=(img-model_data_round)).writeto("model_data_round_residuals.fits", overwrite=True)
    pyfits.PrimaryHDU(data=(model_data_full-model_data_round)).writeto("model_data_bar.fits", overwrite=True)
    pyfits.PrimaryHDU(data=(img-model_data_full+model_data_round)).writeto("model_data_barsub.fits", overwrite=True)

    #
    # build model galaxy based on SBP
    #
    # print("Reconstructing image from ellipse fit")
    # model_image = photutils.isophote.build_ellipse_model(
    #     img.shape, isophot_list
    # )
    # residual_image = img - model_image
    # out_hdu = pyfits.HDUList(
    #     pyfits.PrimaryHDU(),
    #     pyfits.ImageHDU(header=img[0].header, data=img, name="IMAGE"),
    #     pyfits.ImageHDU(header=img[0].header, data=model_image, name="MODEL"),
    #     pyfits.ImageHDU(header=img[0].header, data=residual_image, name="RESIDUALS"),
    # )
    # out_hdu.writeto("ellipsefit.fits", overwrite=True)
