#!/usr/bin/env python3


import os
import sys
import astropy.io.fits as pyfits
import numpy
import pandas
import photutils
import photutils.isophote
from astropy.io import votable
import scipy.interpolate
import scipy.ndimage


# Maximum fraction of a rotated pixel's (linear) interpolation weight that may
# originate from masked input pixels before the rotated pixel is rejected.
MASK_CONTAMINATION_LIMIT = 1e-3


def read_primary_image(filename, dtype=None):
    """Return a copy of the primary-HDU data of the FITS file *filename*.

    The file is closed before returning; the data is copied so the returned
    array does not depend on the (possibly memory-mapped) open file.
    """
    with pyfits.open(filename) as hdulist:
        return numpy.array(hdulist[0].data, dtype=dtype)


def estimate_background(img, segm):
    """Return ``(median, sigma)`` of all pixels with segmentation value 0.

    sigma is half the distance between the 16th and 84th percentile.
    """
    background_pixels = (segm == 0)
    p16, p50, p84 = numpy.nanpercentile(img[background_pixels], [16, 50, 84])
    return p50, 0.5 * (p84 - p16)


def sextractor_to_pixel_index(coord):
    """Convert a 1-based SExtractor image coordinate to a 0-based array index."""
    return int(numpy.round(coord - 1))


def center_pad_widths(shape, x0, y0):
    """Return ``(pad_y, pad_x)`` that put pixel (x0, y0) at the array center.

    The padded array has an odd size along both axes, so the pivot pixel
    coincides exactly with ``(size - 1) / 2``, the point around which
    scipy.ndimage.rotate rotates.
    """
    pad_y = (shape[0] - y0, y0 + 1)
    pad_x = (shape[1] - x0, x0 + 1)
    return pad_y, pad_x


def rotate_and_crop(padded_img, padded_bad, pad_y, pad_x, rotation_angle):
    """Rotate a padded image about its center and undo the padding.

    padded_img     -- image padded with center_pad_widths()
    padded_bad     -- float array of the same shape, 1 where a pixel must not
                      be used, 0 elsewhere
    pad_y, pad_x   -- the pad widths that were applied
    rotation_angle -- rotation angle in degrees

    Returns an array of the original (un-padded) shape in which every pixel
    that received more than MASK_CONTAMINATION_LIMIT of its weight from
    unusable pixels is NaN.
    """
    # Reshaping is disabled because the image is cropped back by hand below.
    rotated = scipy.ndimage.rotate(padded_img, rotation_angle, reshape=False)

    # The bad-pixel map is rotated separately (linear interpolation, no
    # overshoot) instead of marking the pixels inside the image itself, so
    # the spline interpolation of the image only ever sees finite values.
    bad_weight = scipy.ndimage.rotate(
        padded_bad, rotation_angle, reshape=False, order=1)
    rotated[bad_weight > MASK_CONTAMINATION_LIMIT] = numpy.nan

    # Crop with explicit end indices; a negative end index would yield an
    # empty array whenever a pad width is zero.
    height = padded_img.shape[0] - pad_y[0] - pad_y[1]
    width = padded_img.shape[1] - pad_x[0] - pad_x[1]
    return rotated[pad_y[0]:pad_y[0] + height, pad_x[0]:pad_x[0] + width]


def fill_with_noise(img_diff, use_for_fit, bg_sigma):
    """Return a copy of *img_diff* with unusable pixels replaced by noise.

    Pixels are replaced by Gaussian noise of width *bg_sigma* when they are
    negative, not finite, or not flagged in *use_for_fit*.
    """
    noise_image = numpy.random.randn(*img_diff.shape) * bg_sigma
    with numpy.errstate(invalid="ignore"):
        fill_in_noise = (img_diff < 0) | ~numpy.isfinite(img_diff) | ~use_for_fit
    model = img_diff.copy()
    model[fill_in_noise] = noise_image[fill_in_noise]
    return model


def main():
    """Isolate the non-axisymmetric light of one source in an image.

    Usage: <script> image.fits catalog.vot source_id segmentation.fits

    The image is rotated around the selected source by a range of angles; the
    median of the rotated copies is subtracted from the original image, and
    what remains (with negative and foreign pixels replaced by background
    noise) is written out as the model, together with several intermediate
    FITS files in the current directory.
    """
    if len(sys.argv) < 5:
        sys.exit("usage: %s image.fits catalog.vot source_id segmentation.fits"
                 % sys.argv[0])

    fn = sys.argv[1]
    catalog_fn = sys.argv[2]
    source_id = int(sys.argv[3])
    segm_fn = sys.argv[4]

    # read image; floating point is required since rejected pixels become NaN
    img = read_primary_image(fn, dtype=float)

    # read catalog
    vot = votable.parse_single_table(catalog_fn)
    catalog = vot.to_table()
    print(catalog)
    print(source_id)
    right_source = (catalog['NUMBER'] == source_id)
    if not numpy.any(right_source):
        sys.exit("source %d not found in catalog %s" % (source_id, catalog_fn))
    source_data = catalog[right_source][0]
    print("SOURCE-coord: ", source_data['X_IMAGE'], source_data['Y_IMAGE'])

    # read segmentation frame; only sky and the selected source may be used
    segm = read_primary_image(segm_fn)
    if segm.shape != img.shape:
        sys.exit("image %s and segmentation frame %s differ in shape"
                 % (fn, segm_fn))
    use_for_fit = (segm == 0) | (segm == source_id)

    bg_median, bg_sigma = estimate_background(img, segm)
    print("background: %f +/- %f" % (bg_median, bg_sigma))

    #
    # take center position of galaxy from source extractor, and pad the image
    # so the galaxy is at the center of the array
    #
    # adapted from https://stackoverflow.com/questions/25458442/rotate-a-2d-image-around-specified-origin-in-python/25459080
    #
    x0 = sextractor_to_pixel_index(source_data['X_IMAGE'])
    y0 = sextractor_to_pixel_index(source_data['Y_IMAGE'])
    if not (0 <= x0 < img.shape[1] and 0 <= y0 < img.shape[0]):
        sys.exit("source %d lies outside of the image" % source_id)
    pad_y, pad_x = center_pad_widths(img.shape, x0, y0)
    imgP = numpy.pad(img, [pad_y, pad_x], 'constant')
    # numpy.pad does not carry a masked-array mask along, so the unusable
    # pixels are tracked in an explicit map that is padded the same way.
    badP = numpy.pad((~use_for_fit).astype(float), [pad_y, pad_x], 'constant')
    pyfits.PrimaryHDU(data=imgP).writeto("img_padded.fits", overwrite=True)

    img_list = []
    for rotation_angle in numpy.arange(30, 151, 5):
        imgC = rotate_and_crop(imgP, badP, pad_y, pad_x, rotation_angle)
        pyfits.PrimaryHDU(data=imgC).writeto(
            "img_rotated_%03d.fits" % (rotation_angle), overwrite=True)

        img_diff = img - imgC
        pyfits.PrimaryHDU(data=img_diff).writeto(
            "img_difference_%03d.fits" % (rotation_angle), overwrite=True)

        img_list.append(imgC)

    print(len(img_list))
    img_stack = numpy.array(img_list)
    print(img_stack.shape)
    # NaN-aware median, so rejected pixels do not take part in the stack
    img_median = numpy.nanmedian(img_stack, axis=0)
    print(img_median.shape)
    pyfits.PrimaryHDU(data=img_median).writeto("img_median.fits", overwrite=True)

    img_diff = img - img_median
    pyfits.PrimaryHDU(data=img_diff).writeto(
        "img_difference_median.fits", overwrite=True)

    bar_model = fill_with_noise(img_diff, use_for_fit, bg_sigma)
    print(bar_model.shape)
    pyfits.PrimaryHDU(data=bar_model).writeto(
        "img_difference_median2.fits", overwrite=True)

    residuals = img - bar_model
    pyfits.PrimaryHDU(data=residuals).writeto(
        "img_difference_residuals.fits", overwrite=True)


if __name__ == "__main__":
    main()
