#!/usr/bin/env python3

import os
import sys
import astropy.io.fits as pyfits

import numpy
import scipy.optimize
import scipy.ndimage


#
# Model functions used to optimize the difference image: smooth one frame
# and compare it against the other
#
def smooth_image(p, good):
    """Smooth a 2-D image with an isotropic Gaussian kernel.

    Parameters
    ----------
    p : sequence of float
        Parameter vector; only ``p[0]`` is used. Its magnitude is the
        Gaussian sigma (in array elements) applied along both image axes.
    good : array_like
        2-D image to be smoothed.

    Returns
    -------
    numpy.ndarray
        Smoothed image with the same shape as `good`. Pixels outside the
        input are treated as 0 (``mode='constant', cval=0``).
    """
    # The sign of the width carries no meaning.  scipy's gaussian_filter does
    # not smooth at all along an axis whose sigma is not positive, which would
    # make every negative trial value of an optimizer look identical; using
    # the magnitude keeps the model symmetric in p[0].
    sigma = abs(p[0])
    filtered = scipy.ndimage.gaussian_filter(good, [sigma, sigma], order=0, mode='constant', cval=0)
    return filtered


def diffimage(p, good, bad):
    """Residual vector for fitting the smoothing width with leastsq.

    Parameters
    ----------
    p : sequence of float
        Parameter vector handed on to `smooth_image` (smoothing width).
    good : array_like
        2-D image that is smoothed with the trial width.
    bad : array_like
        2-D image, same shape as `good`, that the smoothed image should match.

    Returns
    -------
    numpy.ndarray
        Flattened pixel-by-pixel difference ``bad - smooth_image(p, good)``.
    """
    filtered = smooth_image(p, good)
    diff = bad - filtered
    # scipy.optimize.leastsq squares and sums the returned values itself, so
    # the plain residuals are returned here; squaring them beforehand would
    # make the fit minimize the fourth power of the differences.
    return diff.ravel()




if __name__ == "__main__":
    # Usage: <script> good.fits bad.fits regions.reg [difference.fits]
    #
    # Processing steps:
    #   1. estimate and subtract a frame-wide background level from both
    #      input frames (sources masked via a SExtractor segmentation image),
    #   2. for every "box(" entry of the region file, fit the Gaussian width
    #      that makes the smoothed "good" cutout best match the "bad" cutout,
    #   3. smooth the full "good" frame with the median fitted width, write it
    #      to corrected.fits and write (bad - corrected) to the output file.

    # Local import: only the command line assembled below needs shell quoting.
    import shlex

    if len(sys.argv) < 4:
        sys.exit("Usage: %s good.fits bad.fits regions.reg [difference.fits]" % (sys.argv[0]))

    fn_good = sys.argv[1]
    fn_bad = sys.argv[2]
    region_fn = sys.argv[3]
    # the output file name is optional
    outfn = sys.argv[4] if len(sys.argv) > 4 else "difference.fits"

    #
    # now open the two fits files
    #
    good_hdu = pyfits.open(fn_good)
    bad_hdu = pyfits.open(fn_bad)

    #
    # find background estimate for entire frame
    #
    # NOTE(review): default.param is generated here, but the sex call below
    # is pointed at matchpsf.param -- confirm which parameter file is meant.
    if (not os.path.isfile("default.param")):
        os.system("sex -dp > default.param")

    for (fn, hdu) in [(fn_good, good_hdu), (fn_bad, bad_hdu)]:
        # derived files sit next to the input: <name without extension>.<kind>.fits
        basename = os.path.splitext(fn)[0]

        segmentation_fn = basename + ".segmentation.fits"
        if (not os.path.isfile(segmentation_fn)):
            # create segmentation mask so we can isolate and estimate the sky
            # background; file names are quoted so that blanks or shell
            # metacharacters in a path do not break the command
            sex_cmd = "sex -PARAMETERS_NAME matchpsf.param -CHECKIMAGE_TYPE SEGMENTATION -CHECKIMAGE_NAME %s -FILTER N %s" % (
                shlex.quote(segmentation_fn), shlex.quote(fn)
            )
            os.system(sex_cmd)
        print("Reading segmentation file %s" % (segmentation_fn))
        with pyfits.open(segmentation_fn) as segmentation_hdu:
            # every non-zero segmentation pixel is treated as belonging to a source
            source_contaminated = (segmentation_hdu[0].data != 0)

        # work on a copy so that masking does not touch the actual image data
        img = hdu[0].data.copy()
        img[source_contaminated] = numpy.nan

        # also open the weight image, if there is one; pixels with
        # non-positive weight are excluded as well
        weight_fn = basename + ".weight.fits"
        if (os.path.isfile(weight_fn)):
            print("Handling weight file %s" % (weight_fn))
            with pyfits.open(weight_fn) as weight_hdu:
                img[weight_hdu[0].data <= 0] = numpy.nan

        # now scatter some boxes around the image and measure background levels
        boxsize = 10
        nboxes = 5000
        # Corners are stored as (x, y): x runs along the columns (shape[1]),
        # y along the rows (shape[0]), matching the img[y1:y2, x1:x2] access.
        max_corner = numpy.array([img.shape[1]-2*boxsize, img.shape[0]-2*boxsize])
        lowerleft_ = numpy.random.rand(nboxes, 2) * max_corner
        lowerleft = numpy.round(lowerleft_, 0).astype(int)
        upperright = lowerleft + 2*boxsize

        # columns: median level, box center x, box center y
        median_levels = numpy.zeros((nboxes, 3))
        for i in range(nboxes):
            x1, y1 = lowerleft[i]
            x2, y2 = upperright[i]
            median_levels[i, 0] = numpy.nanmedian(img[y1:y2, x1:x2])
            median_levels[i, [1, 2]] = 0.5*numpy.array([x1+x2, y1+y2])

        numpy.savetxt(fn+".background", median_levels)

        # Iteratively reject boxes more than 3 sigma away from the median;
        # sigma is taken as half the distance between the 16th and 84th
        # percentile of the boxes still accepted.
        good_data = numpy.ones((nboxes), dtype=bool)
        for iteration in range(3):
            _perc = numpy.nanpercentile(median_levels[:, 0][good_data], [16, 50, 84])
            _sigma = 0.5*(_perc[2]-_perc[0])
            _median = _perc[1]
            print(iteration, _median, _sigma)
            outlier = (median_levels[:, 0] > (_median+3*_sigma)) | (median_levels[:, 0] < (_median-3*_sigma))
            good_data[outlier] = False
        bglevel = numpy.nanmedian(median_levels[:, 0][good_data])

        hdu[0].data -= bglevel
        bgsub_fn = basename + ".bgsub.fits"
        if (not os.path.isfile(bgsub_fn)):
            hdu.writeto(bgsub_fn, overwrite=True)

    #
    # open the region file, extract regions to work on
    #
    image_regions = []
    with open(region_fn, "r") as reg:
        for line in reg:
            if (not line.startswith("box(")):
                continue
            # now we have the relevant line: box(<x>,<y>,<size x>,<size y>,...)
            items = numpy.array([float(f) for f in line.split("box(")[1].split(")")[0].split(",")])
            i_items = numpy.round(items, 0).astype(numpy.int32)

            # NOTE(review): the third/fourth number is applied on BOTH sides
            # of the center; if these are full box widths (as in a ds9 box
            # region) the cutout is twice as large as drawn -- confirm intended.
            x1 = i_items[0] - i_items[2]
            x2 = i_items[0] + i_items[2]
            y1 = i_items[1] - i_items[3]
            y2 = i_items[1] + i_items[3]
            image_regions.append([x1, x2, y1, y2])

    if (not image_regions):
        # without any star there is nothing to fit; the median below would be
        # NaN and the "corrected" frame would silently come out unsmoothed
        sys.exit("No box( regions found in %s, unable to determine smoothing width" % (region_fn))

    best_fit_results = []
    bg_levels = []
    for starnumber, box_region in enumerate(image_regions):

        x1, x2, y1, y2 = box_region

        # The start index is shifted down by one (presumably region
        # coordinates are 1-based) and clamped at 0, since a negative start
        # would wrap around to the far edge of the image.
        rows = slice(max(y1-1, 0), y2)
        cols = slice(max(x1-1, 0), x2)

        # Copy the cutouts: a plain slice is a view, and the in-place
        # background subtraction below would otherwise also change the full
        # frames that are smoothed and subtracted after this loop.
        cutout_good = good_hdu[0].data[rows, cols].copy()
        cutout_bad = bad_hdu[0].data[rows, cols].copy()

        #
        # now we have two matching small patches, that both should contain a single star
        #

        #
        # estimate some background from the image edges (outermost 5 pixels)
        #
        mask = numpy.ones_like(cutout_good, dtype=bool)
        mask[5:-5, 5:-5] = False

        bg_good = numpy.nanmedian(cutout_good[mask])
        bg_bad = numpy.nanmedian(cutout_bad[mask])
        bg_levels.append([bg_good, bg_bad])

        cutout_good -= bg_good
        cutout_bad -= bg_bad

        # fit the smoothing width that best maps the good onto the bad cutout
        p_init = [4.]
        fit_results = scipy.optimize.leastsq(
            func=diffimage,
            x0=p_init,
            args=(cutout_good, cutout_bad),
            full_output=True,
        )

        # first element of the full output is the best-fit parameter vector
        p_best = fit_results[0]
        print("Star %d: %.5f" % (starnumber+1, p_best[0]))

        best_smoothed = smooth_image(p_best, cutout_good)

        # per-star diagnostic file with inputs, best model and residuals
        imglist = [
            pyfits.PrimaryHDU(),
            pyfits.ImageHDU(cutout_good, name="GOOD"),
            pyfits.ImageHDU(cutout_bad, name="BAD"),
            pyfits.ImageHDU(best_smoothed, name="BEST"),
            pyfits.ImageHDU((cutout_bad-best_smoothed), name="BESTDIFF"),
        ]
        hdu_demo = pyfits.HDUList(imglist)
        hdu_demo.writeto("dummydemo_%d.fits" % (starnumber), overwrite=True)

        best_fit_results.append(p_best)

    # only the magnitude of the fitted widths is used from here on
    best_fit_results = numpy.fabs(numpy.array(best_fit_results))
    print(best_fit_results)

    print("star background levels:", numpy.array(bg_levels))
    bg_levels = numpy.array(bg_levels)

    best_correction = numpy.median(best_fit_results)
    print("Final: %f" % (best_correction))

    corrected = smooth_image([best_correction], good_hdu[0].data)
    print("Done applying correction")

    good_hdu[0].data = corrected
    good_hdu.writeto("corrected.fits", overwrite=True)

    # typical (median) local background of the good / bad cutouts; printed
    # for information only, it is not applied to the frames
    mean_backgrounds = numpy.nanmedian(bg_levels, axis=0)
    print(mean_backgrounds)

    print("writing difference image")
    diff = bad_hdu[0].data - corrected
    bad_hdu[0].data = diff

    bad_hdu.writeto(outfn, overwrite=True)

