#!/usr/bin/env python3

import os
import subprocess
import sys

import astropy.io.fits as pyfits
import numpy
import scipy.ndimage


# SExtractor configuration file and output-parameter list, passed to `sex`
# via -c and -PARAMETERS_NAME respectively (site-specific absolute paths).
sex_conf, sex_param = (
    '/work/cmzs/wise_ngc1300/sex.conf',
    '/work/cmzs/wise_ngc1300/default.param',
)
def run_sextractor(image_fn, segmentation_fn, conf, param):
    """Run SExtractor on *image_fn*, writing its check-image to *segmentation_fn*.

    Nothing is done when *segmentation_fn* already exists, so earlier maps are
    reused (delete them to force a fresh run). The check-image type is not set
    here; presumably sex.conf selects SEGMENTATION — confirm.

    The command is passed as an argument list (no shell), so file names with
    spaces or shell metacharacters are safe.

    Raises:
        subprocess.CalledProcessError: if SExtractor exits non-zero.
        FileNotFoundError: if the `sex` executable cannot be found.
    """
    if os.path.isfile(segmentation_fn):
        return
    subprocess.run(
        [
            "sex", "-c", conf, "-PARAMETERS_NAME", param,
            "-CHECKIMAGE_NAME", segmentation_fn, image_fn,
        ],
        check=True,
    )


def read_primary(fn):
    """Return ``(data, header)`` of the primary HDU of FITS file *fn*.

    The file is closed before returning. ``memmap=False`` loads the pixels
    into memory, so the returned array stays usable after the file is closed.
    """
    with pyfits.open(fn, memmap=False) as hdul:
        return hdul[0].data, hdul[0].header


def estimate_background(data, segm, n_iter=3, n_sigma=3.0):
    """Estimate the sky level of *data* by iterative sigma clipping.

    Background candidates start as the finite pixels with ``segm == 0``.
    Each of the *n_iter* passes works on the remaining candidates:

    * compute their median and a robust sigma, 0.5 * (84th - 16th percentile);
    * drop every pixel lying more than *n_sigma* sigma from that median.

    Returns:
        ``(bg, median, sigma)``: ``bg`` is the median of the surviving
        pixels. ``median`` and ``sigma`` are from the last clipping pass and
        are returned for diagnostics.

    Raises:
        ValueError: if no background pixels remain to measure.
    """
    good_bg = numpy.isfinite(data) & (segm == 0)
    median = sigma = numpy.nan
    for _ in range(n_iter):
        if not good_bg.any():
            raise ValueError("no background pixels left for the sky estimate")
        p16, median, p84 = numpy.nanpercentile(data[good_bg], [16, 50, 84])
        sigma = 0.5 * (p84 - p16)
        outlier = (data > median + n_sigma * sigma) | (data < median - n_sigma * sigma)
        good_bg[outlier] = False
    if not good_bg.any():
        raise ValueError("no background pixels left for the sky estimate")
    return numpy.nanmedian(data[good_bg]), median, sigma


def kernel_sigma_pixels(fwhm_from, fwhm_to, pixel_scale):
    """Return the Gaussian sigma, in pixels, that broadens FWHM *fwhm_from* to *fwhm_to*.

    Gaussian widths add in quadrature, so the matching kernel has
    FWHM = sqrt(fwhm_to**2 - fwhm_from**2). The difference is taken in the
    FWHMs' own (angular) units and only then divided by *pixel_scale*, which
    is the scale of the image being smoothed. This keeps the kernel correct
    when the two bands have different pixel scales.

    ``abs`` guards against a header that stores a signed (direction-carrying)
    scale.

    Raises:
        ValueError: if *fwhm_to* is not larger than *fwhm_from*.
    """
    if fwhm_to <= fwhm_from:
        raise ValueError("target FWHM must exceed the input FWHM")
    kernel_fwhm = numpy.sqrt(fwhm_to ** 2 - fwhm_from ** 2)
    # 2.355 ~ 2*sqrt(2*ln 2), the FWHM / sigma ratio of a Gaussian
    return kernel_fwhm / abs(pixel_scale) / 2.355


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("usage: %s W1_IMAGE W4_IMAGE" % sys.argv[0])
    w1_fn = sys.argv[1]
    w4_fn = sys.argv[2]

    # run SExtractor on both frames; existing segmentation maps are reused
    run_sextractor(w1_fn, "segmentation_w1.fits", sex_conf, sex_param)
    run_sextractor(w4_fn, "segmentation_w4.fits", sex_conf, sex_param)

    w1, w1_header = read_primary(w1_fn)
    w1_segm, _ = read_primary("segmentation_w1.fits")
    w4, w4_header = read_primary(w4_fn)
    w4_segm, _ = read_primary("segmentation_w4.fits")
    print(w1_header['PXSCAL1'], w4_header['PXSCAL1'])

    # subtract the sky in place, so w1/w4 are background-subtracted from here on
    for data, segm in ((w1, w1_segm), (w4, w4_segm)):
        bg, median, sigma = estimate_background(data, segm)
        print(median, sigma, bg)
        data -= bg

    # Smoothing kernel: W1 and W4 FWHMs (arcsec). Assumes PXSCAL1 is in
    # arcsec/pixel so the FWHMs convert to pixels — confirm against headers.
    w1_fwhm, w4_fwhm = 6.1, 12.0
    w1_scale = w1_header['PXSCAL1']
    w4_scale = w4_header['PXSCAL1']
    print(w1_fwhm / abs(w1_scale) / 2.355, w4_fwhm / abs(w4_scale) / 2.355)
    kernel_sigma = kernel_sigma_pixels(w1_fwhm, w4_fwhm, w1_scale)
    print("kernel-size:", kernel_sigma)

    # NOTE: NaN pixels in w1 (if any) spread over the kernel footprint here.
    w1_smoothed = scipy.ndimage.gaussian_filter(
        input=w1, sigma=kernel_sigma, order=0,
        mode='constant', cval=0.,
    )
    pyfits.PrimaryHDU(data=w1, header=w1_header).writeto("w1_bgsub.fits", overwrite=True)
    pyfits.PrimaryHDU(data=w4, header=w4_header).writeto("w4_bgsub.fits", overwrite=True)
    pyfits.PrimaryHDU(data=w1_smoothed, header=w1_header).writeto("w1_smoothed.fits", overwrite=True)

    mag_w1 = w1_header['MAGZP']
    mag_w4 = w4_header['MAGZP']
    d_zp = mag_w4 - mag_w1
    w1_match_w4 = numpy.power(10., -0.4 * d_zp)
    print(w1_match_w4)

    print(numpy.sum(w1), numpy.sum(w1_smoothed))

    flux_w1 = 7.32e-12
    flux_w4 = 5.11e-15  # from galev, based on vega spectrum at 3.4 and 24 mu

    w1_to_w4 = 1 / w1_match_w4 * (flux_w4 / flux_w1)
    print(w1_to_w4)

    zp_scaling = (flux_w1 / numpy.power(10., 0.4 * mag_w1)) / (flux_w4 / numpy.power(10., 0.4 * mag_w4))
    print(zp_scaling)

    # NOTE(review): the zero-point-derived factor above is only printed. The
    # subtraction uses this empirical divisor instead; its origin is not
    # documented here — confirm before changing.
    w1_divisor = 12.09
    if w1_smoothed.shape != w4.shape:
        raise ValueError(
            f"W1 and W4 must share a pixel grid; got shapes "
            f"{w1_smoothed.shape} and {w4.shape}"
        )
    diff = w4 - w1_smoothed / w1_divisor
    pyfits.PrimaryHDU(data=diff, header=w1_header).writeto("w1-w4.fits", overwrite=True)
