#!/usr/bin/env python3

import os
import sys
import numpy
import astropy.io.fits as pyfits
import matplotlib
import matplotlib.path as mpltPath

def label_offset(max_mask_id):
    """Return the offset added to polygon numbers to form new mask labels.

    The offset is the smallest power of ten that is >= ``max_mask_id``, so
    that ``offset + 1``, ``offset + 2``, ... never collide with a label that
    is already present in the mask.

    Parameters
    ----------
    max_mask_id : int or float
        Largest label value currently in the mask.

    Returns
    -------
    int
        The offset; 0 when ``max_mask_id`` is zero or negative (any positive
        label is free in that case, and log10 is undefined there).
    """
    # Work in float64 regardless of the mask dtype so that small integer
    # dtypes (int16, uint8, ...) do not drag the computation into float32.
    value = float(max_mask_id)
    if value <= 0:
        return 0
    return int(numpy.power(10.0, numpy.ceil(numpy.log10(value))))


def polygon_vertices(line):
    """Parse one ds9 ``polygon(x1,y1,x2,y2,...)`` line into an (N, 2) array.

    Anything after the closing parenthesis (e.g. ``# color=red``) is ignored.

    Raises
    ------
    ValueError
        If a coordinate is not a plain number or the number of coordinates
        is odd.
    """
    coord_text = line.split("polygon(", 1)[1].split(")", 1)[0]
    values = [float(token) for token in coord_text.split(",")]
    return numpy.array(values).reshape((-1, 2))


def add_polygon_labels(mask, region_lines, first_label):
    """Paint every polygon found in ``region_lines`` into ``mask`` in place.

    The n-th polygon (counting from 1) is written with the label
    ``first_label + n``.  Lines that do not start with ``polygon(`` are
    skipped; later polygons overwrite earlier ones where they overlap.

    Parameters
    ----------
    mask : numpy.ndarray
        2-D integer array, modified in place.
    region_lines : iterable of str
        Lines of a ds9 region file.
        NOTE(review): vertices are assumed to be in ds9 image (pixel)
        coordinates; WCS region files are not handled here.
    first_label : int
        Offset added to the running polygon number.

    Returns
    -------
    int
        Number of polygons that were painted.
    """
    iy, ix = numpy.indices(mask.shape)
    # (x, y) of every pixel, shifted by +1 because the region coordinates
    # are compared against 1-based pixel positions.
    pixel_xy = numpy.column_stack((ix.ravel(), iy.ravel())) + 1

    polygon_id = 0
    for line in region_lines:
        if not line.startswith("polygon("):
            # we only care about polygon regions
            continue

        polygon_id += 1
        outline = mpltPath.Path(polygon_vertices(line))
        inside = outline.contains_points(pixel_xy).reshape(mask.shape)
        mask[inside] = first_label + polygon_id

        print(line)

    return polygon_id


if __name__ == "__main__":

    if len(sys.argv) < 4:
        sys.exit("Usage: %s mask_in.fits regions.reg mask_out.fits"
                 % sys.argv[0])

    mask_in_fn = sys.argv[1]
    region_fn = sys.argv[2]
    mask_out_fn = sys.argv[3]

    # Copy the pixel data out while the file is open so the handle can be
    # released right away.
    with pyfits.open(mask_in_fn) as hdulist:
        mask_in = hdulist[0].data
        if mask_in is None:
            sys.exit("No image data in the primary HDU of %s" % mask_in_fn)
        max_mask_id = numpy.max(mask_in)
        # Widen to the platform integer *before* writing labels: the new
        # labels can exceed the range of a narrow input dtype such as int16.
        mask_out = mask_in.astype(int)

    start_number = label_offset(max_mask_id)
    print(max_mask_id, start_number)

    # read ds9 region file
    with open(region_fn, "r") as reg:
        add_polygon_labels(mask_out, reg, start_number)

    # numpy.int was removed in NumPy 1.24; mask_out already has dtype int.
    phdu = pyfits.PrimaryHDU(data=mask_out)
    phdu.writeto(mask_out_fn, overwrite=True)
    print("Results written to %s" % (mask_out_fn))
