#!/usr/bin/env python3

import os
import subprocess
import sys

import astropy.io.fits as pyfits
import numpy
import scipy.ndimage


# Build a background-subtracted, smoothed and binned stack of the four WISE
# bands, then extract a rectangular region of the binned stack and compute
# its W4-normalised spectral energy distribution (SED).
#
# Usage: script W1.fits W2.fits W3.fits W4.fits x0 x1 y0 y1
#   x0 x1 y0 y1 -- cutout bounds in binned-pixel coordinates, with Python
#                  slice semantics (x0 <= x < x1, y0 <= y < y1).
#
# Files written to the current directory: segmentation_w<N>.fits (only if
# missing), stack.fits, stackbinned.fits, dummywisesed, dummywisesednorm.

# SExtractor configuration used to build the per-band segmentation maps.
sex_conf = '/work/cmzs/wise_ngc1300/sex.conf'
sex_param = '/work/cmzs/wise_ngc1300/default.param'

N_BANDS = 4          # W1..W4; the last band is the normalisation band
N_COORDS = 4         # x0 x1 y0 y1
BIN_FACTOR = 5       # side length (pixels) of the summed blocks
SMOOTH_SIGMA = 3.2   # Gaussian sigma (pixels) applied to W1-W3
# W4 percentile above which pixels enter the median SED. 60 selects the
# brightest ~40% of pixels (an earlier comment said "20 brightest %";
# NOTE(review): confirm which selection is intended).
BRIGHT_PERCENTILE = 60


def run_sextractor(image_fn, segm_fn):
    """Run SExtractor on *image_fn*, writing its check-image to *segm_fn*.

    Assumes ``sex_conf`` sets ``CHECKIMAGE_TYPE SEGMENTATION`` (the map is
    later read with 0 meaning "no source") -- TODO confirm.

    The command is passed as an argument list rather than a shell string, so
    file names containing spaces or shell metacharacters are handled safely.

    Raises:
        subprocess.CalledProcessError: if SExtractor exits with an error.
        FileNotFoundError: if the ``sex`` executable cannot be found.
    """
    subprocess.run(
        ["sex", "-c", sex_conf, "-PARAMETERS_NAME", sex_param,
         "-CHECKIMAGE_NAME", segm_fn, image_fn],
        check=True,
    )


def estimate_background(data, segm, n_iter=3, nsigma=3.0):
    """Return the sigma-clipped median sky level of *data*.

    Only finite pixels with ``segm == 0`` (not assigned to any source) are
    used. Each of the *n_iter* passes measures the median and a robust sigma
    (half the 16th-84th percentile range) of the remaining pixels and rejects
    every pixel lying more than ``nsigma * sigma`` from that median.

    Raises:
        ValueError: if no usable background pixels remain.
    """
    good_bg = numpy.isfinite(data) & (segm == 0)
    for _ in range(n_iter):
        if not good_bg.any():
            break
        # data[good_bg] is all finite, so plain percentile is sufficient.
        p16, p50, p84 = numpy.percentile(data[good_bg], [16, 50, 84])
        sigma = 0.5 * (p84 - p16)
        outlier = (data > p50 + nsigma * sigma) | (data < p50 - nsigma * sigma)
        good_bg &= ~outlier
    if not good_bg.any():
        raise ValueError("no background pixels left to estimate the sky level")
    return float(numpy.median(data[good_bg]))


def bin_image(data, factor=BIN_FACTOR):
    """Sum 2-D *data* over non-overlapping ``factor x factor`` pixel blocks.

    Trailing rows/columns that do not fill a complete block are dropped, so
    any image size is accepted (e.g. 4095x4095 with factor 5 -> 819x819).
    """
    ny, nx = data.shape[0] // factor, data.shape[1] // factor
    blocks = data[:ny * factor, :nx * factor].reshape(ny, factor, nx, factor)
    return blocks.sum(axis=(1, 3))


def median_sed(sed_norm, ref_flux, percentile=BRIGHT_PERCENTILE):
    """Median normalised SED of the pixels brightest in the reference band.

    Args:
        sed_norm: (n_bands, n_pixels) array of normalised fluxes.
        ref_flux: (n_pixels,) reference-band fluxes used for the selection.
        percentile: pixels with ``ref_flux`` strictly above this percentile
            (computed ignoring NaNs) are selected.

    Returns:
        ``(threshold, picked, median)``: the flux threshold, the boolean
        selection mask, and the per-band median over the selected pixels.
    """
    threshold = numpy.nanpercentile(ref_flux, percentile)
    picked = ref_flux > threshold
    return threshold, picked, numpy.median(sed_norm[:, picked], axis=1)


def main(argv=None):
    """Run the full pipeline.

    *argv* is the argument list without the program name (defaults to
    ``sys.argv[1:]``): four WISE images W1-W4 followed by ``x0 x1 y0 y1``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    # Validate up front, so bad arguments are not discovered only after the
    # slow per-band processing and file writing.
    if len(args) < N_BANDS + N_COORDS:
        sys.exit("usage: %s W1.fits W2.fits W3.fits W4.fits x0 x1 y0 y1"
                 % os.path.basename(sys.argv[0]))
    image_fns = args[:N_BANDS]
    x0, x1, y0, y1 = (int(v) for v in args[N_BANDS:N_BANDS + N_COORDS])

    stack_hdus = [pyfits.PrimaryHDU()]
    binned_hdus = [pyfits.PrimaryHDU()]
    binned_data = []
    for band, image_fn in enumerate(image_fns, start=1):
        print("Working on band W%d" % band)

        # Segmentation maps are cached: SExtractor runs only if one is missing.
        segm_fn = "segmentation_w%d.fits" % band
        if not os.path.isfile(segm_fn):
            run_sextractor(image_fn, segm_fn)

        # Copy the pixel data out so the files can be closed (astropy may
        # memory-map them). Promoting to at least float32 also makes the
        # in-place background subtraction below valid for integer images.
        with pyfits.open(image_fn) as image_hdus:
            header = image_hdus[0].header.copy()
            raw = image_hdus[0].data
            if raw is None:
                raise ValueError("%s: primary HDU contains no image data" % image_fn)
            data = raw.astype(numpy.promote_types(raw.dtype, numpy.float32))
        with pyfits.open(segm_fn) as segm_hdus:
            segm = numpy.array(segm_hdus[0].data)

        data -= estimate_background(data, segm)

        if band <= 3:
            # W1-W3 need smoothing; W4 is left unsmoothed.
            data = scipy.ndimage.gaussian_filter(
                data, sigma=SMOOTH_SIGMA, order=0, mode='constant', cval=0.)

        name = "WISE_W%d" % band
        stack_hdus.append(pyfits.ImageHDU(data=data, header=header, name=name))

        binned = bin_image(data)
        print(binned.shape)
        # NOTE(review): the unbinned header is reused as-is, so any
        # pixel-based keywords (e.g. WCS reference pixel, pixel scale) do not
        # describe the binned grid.
        binned_hdus.append(pyfits.ImageHDU(data=binned, header=header, name=name))
        binned_data.append(binned)

    pyfits.HDUList(stack_hdus).writeto("stack.fits", overwrite=True)
    pyfits.HDUList(binned_hdus).writeto("stackbinned.fits", overwrite=True)

    # Extract the requested region of the binned stack: (n_bands, ny, nx).
    cutout = numpy.array(binned_data)[:, y0:y1, x0:x1]
    print(cutout.shape)
    if cutout.size == 0:
        sys.exit("cutout [%d:%d, %d:%d] is empty" % (y0, y1, x0, x1))
    cutout1d = cutout.reshape((cutout.shape[0], -1))  # one row per band
    numpy.savetxt("dummywisesed", cutout1d)

    # Normalise every band to W4 (last row); pixels with zero W4 give inf/nan.
    cutout1dnorm = cutout1d / cutout1d[-1, :]
    numpy.savetxt("dummywisesednorm", cutout1dnorm)

    threshold, picked, sed = median_sed(cutout1dnorm, cutout1d[-1, :])
    print(threshold)
    print(picked.shape)
    print(picked[:20])
    print(sed.shape)
    print(sed)


# ---------------------------------------------------------------------------
# Earlier PSF-matching experiment (W1 smoothed to the W4 resolution), kept for
# reference only. It refers to variables (w1_hdu, w4_hdu, w1, w4) that this
# script no longer defines.
# ---------------------------------------------------------------------------
# # work out the filtering kernel size
# w1_sigma = 6.1 / w1_hdu[0].header['PXSCAL1'] / 2.355
# w4_sigma = 12.0 / w4_hdu[0].header['PXSCAL1'] / 2.355
# print(w1_sigma, w4_sigma)
#
# kernel_sigma = numpy.sqrt(w4_sigma**2 - w1_sigma**2)
# print("kernel-size:", kernel_sigma)
#
# w1_smoothed = scipy.ndimage.gaussian_filter(
#     input=w1, sigma=kernel_sigma, order=0,
#     mode='constant', cval=0.
# )
# pyfits.PrimaryHDU(data=w1, header=w1_hdu[0].header).writeto("w1_bgsub.fits", overwrite=True)
# pyfits.PrimaryHDU(data=w4, header=w4_hdu[0].header).writeto("w4_bgsub.fits", overwrite=True)
# pyfits.PrimaryHDU(data=w1_smoothed, header=w1_hdu[0].header).writeto("w1_smoothed.fits", overwrite=True)
#
# mag_w1 = w1_hdu[0].header['MAGZP']
# mag_w4 = w4_hdu[0].header['MAGZP']
# d_zp = mag_w4 - mag_w1
# w1_match_w4 = numpy.power(10., -0.4*d_zp)
# print(w1_match_w4)
#
# print(numpy.sum(w1), numpy.sum(w1_smoothed))
#
# flux_w1 = 7.32e-12
# flux_w4 = 5.11e-15 # from galev, based on vega spectrum at 3.4 and 24 mu
#
# w1_to_w4 = 1/w1_match_w4 * (5.11e-15/7.32e-12)
# print(w1_to_w4)
#
# scaling_factor = (flux_w1 / numpy.power(10.,0.4*mag_w1)) / (flux_w4 / numpy.power(10., 0.4*mag_w4))
# print(scaling_factor)
# scaling_factor = 1.
# diff = w4 - (w1_smoothed/scaling_factor)/12.
# pyfits.PrimaryHDU(data=diff, header=w1_hdu[0].header).writeto("w1-w4.fits", overwrite=True)

if __name__ == "__main__":
    main()
