#!/usr/bin/env python3


import numpy
import scipy
import astropy.io.fits as pyfits
import sys
import os
import astropy.table
import astropy.io.votable as votable

def write_galfit(img_hdu, catalog_fn, src_id):
    """Stub: accept the arguments, do nothing, and return ``None``.

    Presumably meant to write a GALFIT feed file for catalog source
    *src_id* of the image in *img_hdu* -- TODO: not implemented.
    """
    return None



# GALFIT feed-file templates, filled with %-style dicts by _galfit_feed().
_GALFIT_HEADER_TEMPLATE = """\
A) %(imgfile)s         # Input data image (FITS file)
B) %(galfit_output)s   # Output data image block
C) %(weight_image)s                # Sigma image name (made from data if blank or "none")
D) %(psf)s            # Input PSF image and (optional) diffusion kernel
E) %(psf_supersample)d                   # PSF fine sampling factor relative to data
F) %(bpm)s                # Bad pixel mask (FITS image or ASCII coord list)
G) %(constraints)s                # File with parameter constraints (ASCII file)
H) %(x1)d %(x2)d %(y1)d %(y2)d   # Image region to fit (xmin xmax ymin ymax)
I) 100    100          # Size of the convolution box (x y)
J) %(magzero).3f              # Magnitude photometric zeropoint
K) %(pixelscale).3f %(pixelscale).3f            # Plate scale (dx dy)    [arcsec per pixel]
O) regular             # Display type (regular, curses, both)
P) 0                   # Choose: 0=optimize, 1=model, 2=imgblock, 3=subcomps
"""

_GALFIT_OBJECT_TEMPLATE = """\
# Object number: 1
0) sersic                 #  object type
1) %(x).3f  %(y).3f  %(free_pos)d %(free_pos)d  #  position x, y
3) %(magnitude).3f     1          #  Integrated magnitude
4) %(halflight_radius).3f      1          #  R_e (half-light radius)   [pix]
5) %(sersic_n).3f      1          #  Sersic index n (de Vaucouleurs n=4)
6) 0.0000      0          #     -----
7) 0.0000      0          #     -----
8) 0.0000      0          #     -----
9) %(axis_ratio).3f      1          #  axis ratio (b/a)
10) %(position_angle).3f    1          #  position angle (PA) [deg: Up=0, Left=90]
Z) 0                      #  output option (0 = resid., 1 = Don't subtract)

# Object number: 2
0) sky                    #  object type
1) 0.0000      1          #  sky background at center of fitting region [ADUs]
2) 0.0000      0          #  dsky/dx (sky gradient in x)
3) 0.0000      0          #  dsky/dy (sky gradient in y)
Z) 0                      #  output option (0 = resid., 1 = Don't subtract)
"""

_PIXEL_SCALE = 0.18  # plate scale written to feed item K) [arcsec per pixel]
_BRIGHT_FLUX_FRACTION = 0.75  # share of pass-1 model flux that defines the "bright" cutoff


def _read_fits(filename):
    """Return ``(data, header)`` of the primary HDU, fully loaded, with the file closed."""
    with pyfits.open(filename, memmap=False) as hdul:
        return hdul[0].data, hdul[0].header


def _write_fits(filename, data, header):
    """Write *data* with *header* as a primary HDU to *filename*, overwriting it."""
    pyfits.PrimaryHDU(header=header, data=data).writeto(filename, overwrite=True)


def _masked_copy(img, bad):
    """Return a floating-point copy of *img* with pixels where *bad* is True set to NaN."""
    # NaN cannot be stored in an integer array, so promote to at least float32.
    out = numpy.array(img, dtype=numpy.result_type(img.dtype, numpy.float32))
    out[bad] = numpy.nan
    return out


def _catalog_row(catalog, src_id):
    """Return the catalog row for source *src_id*.

    Matches on the SExtractor ``NUMBER`` column (the value used in the
    segmentation map) when present; otherwise falls back to treating
    *src_id* as a 1-based row index.
    """
    if 'NUMBER' in catalog.colnames:
        matches = catalog[catalog['NUMBER'] == src_id]
        if len(matches) == 0:
            raise ValueError("source %d not found in catalog" % src_id)
        return matches[0]
    return catalog[src_id - 1]


def _galfit_value(header, key):
    """Return the best-fit number stored as a string in GALFIT header *key*.

    Values look like ``"12.3 +/- 0.4"``. GALFIT may wrap fixed values in
    ``[...]`` and flag suspect ones with ``*...*``, so those characters are
    stripped before conversion.
    """
    return float(str(header[key]).split()[0].strip("[]*"))


def _galfit_feed(img_fn, output_fn, mask_fn, img_shape, src_data, fit_position):
    """Return the text of a GALFIT feed file: one Sersic component plus sky.

    Initial guesses come from the SExtractor catalog row *src_data*. When
    *fit_position* is False the centre is held fixed at the catalog position.
    """
    header = _GALFIT_HEADER_TEMPLATE % {
        'imgfile': img_fn,
        'galfit_output': output_fn,
        'weight_image': 'none',
        # Rendered literally as "None" (unchanged from the original script);
        # NOTE(review): confirm GALFIT treats this as "no PSF convolution".
        'psf': None,
        'psf_supersample': 1,
        'bpm': mask_fn,
        'constraints': '',
        # Fit the whole frame; GALFIT pixel coordinates are 1-based, inclusive.
        'x1': 1,
        'x2': img_shape[1],
        'y1': 1,
        'y2': img_shape[0],
        'magzero': 0,
        'pixelscale': _PIXEL_SCALE,
    }
    objects = _GALFIT_OBJECT_TEMPLATE % {
        # Keep sub-pixel precision: with fit_position=False this IS the centre.
        'x': float(src_data['X_IMAGE']),
        'y': float(src_data['Y_IMAGE']),
        'free_pos': 1 if fit_position else 0,
        'magnitude': float(src_data['MAG_AUTO']),
        'halflight_radius': float(src_data['FLUX_RADIUS']),
        'sersic_n': 1.5,
        # SExtractor ELONGATION is a/b; GALFIT wants b/a.
        'axis_ratio': 1.0 / float(src_data['ELONGATION']),
        # THETA_IMAGE is counter-clockwise from +x; GALFIT PA is from +y (Up=0, Left=90).
        'position_angle': float(src_data['THETA_IMAGE']) - 90.0,
    }
    return header + "\n" + objects


def _require_output(output_fn, feed_fn):
    """Exit with a hint if GALFIT has not yet produced *output_fn* from *feed_fn*."""
    if not os.path.exists(output_fn):
        sys.exit("%s not found: run 'galfit %s', then re-run this script"
                 % (output_fn, feed_fn))


def main(argv=None):
    """Two-pass GALFIT fit of one source, isolating its faint outskirts.

    Usage: ``script IMAGE SEGMENTATION TARGET_ID CATALOG_VOTABLE``

    Pass 1 fits a Sersic + sky model to TARGET_ID with all other segmented
    sources masked. Pixels brighter than the level enclosing
    ``_BRIGHT_FLUX_FRACTION`` of the pass-1 model flux are then masked, and
    pass 2 refits with the centre held fixed. The pass-2 residual is written
    to ``residual.fits``.

    GALFIT itself is not invoked: each pass writes its feed file
    (``galfit.feed`` / ``galfit.feed2``) and expects the corresponding output
    (``galfit_out.fits`` / ``galfit_out2.fits``) to exist. If it does not, the
    script exits with a hint.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 4:
        sys.exit("usage: %s IMAGE SEGMENTATION TARGET_ID CATALOG"
                 % os.path.basename(sys.argv[0]))
    img_fn, segm_fn = argv[0], argv[1]
    target_id = int(argv[2])
    catalog_fn = argv[3]

    # Segmentation map -> GALFIT bad-pixel mask: clearing the target's label
    # leaves only the *other* sources non-zero (= masked).
    segm, segm_header = _read_fits(segm_fn)
    print(segm.shape)
    segm[segm == target_id] = 0
    mask_fn = "galfit_mask.fits"
    _write_fits(mask_fn, segm, segm_header)

    img, img_header = _read_fits(img_fn)
    img_masked = _masked_copy(img, segm > 0)
    print(img_masked.shape)
    _write_fits("image_masked.fits", img_masked, img_header)

    catalog = votable.parse_single_table(catalog_fn).to_table()
    src_data = _catalog_row(catalog, target_id)
    print(src_data)

    # ---- pass 1: fit the whole galaxy -------------------------------------
    galfit_out_fn = "galfit_out.fits"
    feedme_fn = "galfit.feed"
    with open(feedme_fn, "w") as feedfile:
        feedfile.write(_galfit_feed(img_fn, galfit_out_fn, mask_fn, img.shape,
                                    src_data, fit_position=True))
    _require_output(galfit_out_fn, feedme_fn)

    with pyfits.open(galfit_out_fn, memmap=False) as galfit_hdu:
        # Extension 2 of the GALFIT image block: model image + fit parameters.
        galfit_hdr = galfit_hdu[2].header
        galfit_model = galfit_hdu[2].data
    sky_value = _galfit_value(galfit_hdr, '2_SKY')
    print(sky_value)

    galfit_model_bgsub = galfit_model - sky_value
    galfit_peak = numpy.max(galfit_model_bgsub)
    positive = galfit_model_bgsub[galfit_model_bgsub > 0]
    median_intensity = numpy.median(positive) if positive.size else numpy.nan
    print(galfit_peak, median_intensity)

    # Sort model pixels brightest-first and accumulate; the pixel value at
    # which the running sum is closest to the target fraction of the total
    # flux becomes the bright cutoff.
    intensity_sort = numpy.sort(galfit_model_bgsub, axis=None)[::-1]
    numpy.savetxt("intensity_sort", intensity_sort)
    cumulative = numpy.cumsum(intensity_sort)
    numpy.savetxt("cumulative", cumulative)
    total_flux = cumulative[-1]
    print(total_flux)
    best_match = numpy.argmin(
        numpy.fabs(cumulative - _BRIGHT_FLUX_FRACTION * total_flux))
    print(best_match)
    fraction_intensity = intensity_sort[best_match]
    # The image still contains sky, so add it back before comparing to img.
    bright_cutoff = fraction_intensity + sky_value
    print(fraction_intensity, bright_cutoff)

    # ---- pass 2: mask the bright core, fit only the faint outskirts -------
    segm2 = segm.copy()
    segm2[(img > bright_cutoff) & (segm == 0)] = 9999
    mask_fn2 = "galfit_mask2.fits"
    _write_fits(mask_fn2, segm2, segm_header)

    galfit_out_fn2 = "galfit_out2.fits"
    feedme_fn2 = "galfit.feed2"
    with open(feedme_fn2, "w") as feedfile:
        feedfile.write(_galfit_feed(img_fn, galfit_out_fn2, mask_fn2, img.shape,
                                    src_data, fit_position=False))

    _write_fits("image_masked2.fits", _masked_copy(img, segm2 != 0), img_header)

    _require_output(galfit_out_fn2, feedme_fn2)

    # Extension 3 of the pass-2 image block: image with the outer-disk model subtracted.
    with pyfits.open(galfit_out_fn2, memmap=False) as galfit2_hdu:
        residual = galfit2_hdu[3].data
    _write_fits("residual.fits", residual, img_header)


if __name__ == "__main__":
    main()