# Script written by ChatGPT, tidied up by me. Takes two FITS cubes as input and merges (mosaics) them with correct WCS.
# Overwrites the target output file if it exists.

# IMPORTANT NOTES: 1) Only the spatial WCS is adjusted. No transformation is done on the spectral axis at all, i.e.
# channel 57 of file one will be averaged with channel 57 of file two, regardless of their spectral coordinates.
# Averaging is a weighted mean, with each cube's data weighted by 1/rms^2, where rms is the standard deviation of that
# cube's spectrum at the same position (optionally, the 2D rms maps can be written out).
# 2) Only a minimal set of header keywords is copied between files - see the section below that copies the spectral,
# rest-frequency and beam keywords from the first cube's header if you need to adjust these. These are the keywords
# needed for a correct WCS display in standard viewers and for MIRIAD compatibility.
# 3) Because of the noise-weighting, spectra with zero rms (e.g. consisting entirely of values of exactly zero) are not
# handled correctly: wherever the other cube also has valid data, they become NaN in the merged cube, even though the
# other cube has good data there. Such zero-valued blanks should be replaced with NaN beforehand - see the
# "ReplaceBlanks.py" (replace zeros) script for this.

import numpy
from astropy.io import fits
from astropy.wcs import WCS
from reproject import reproject_interp, reproject_exact
from reproject.mosaicking import find_optimal_celestial_wcs
import sys
import warnings


# *** USER-SPECIFIED PARAMETERS ***

# Input cubes to be merged. Each is expected to have shape (nchan, ny, nx).
cube1_file = 'Cube1.fits'
cube2_file = 'Cube2.fits'

# File name for the output (merged) cube.
merge_cube_file = 'MergedCube.fits'

# Name of the rest-frequency keyword exactly as spelled in the first file's header. It is needed so that this
# keyword and its value are carried over to the merged cube.
restfrq = 'RESTFREQ'

# Set to True to also write out the 2D rms maps (RMSMap1.fits / RMSMap2.fits); handy for testing.
make_rmsmaps = False

# *** END OF USER-SPECIFIED PARAMETERS ***


# Suppress some common pointless warning messages
warnings.filterwarnings("ignore")


print('Opening cubes...')
# Read the data and header of the primary HDU of each input file
with fits.open(cube1_file) as hdulist1:
    cube1 = hdulist1[0].data
    header1 = hdulist1[0].header

with fits.open(cube2_file) as hdulist2:
    cube2 = hdulist2[0].data
    header2 = hdulist2[0].header

# Channels are merged index-by-index (channel i of cube 1 with channel i of cube 2), and the number of output
# channels is taken from cube 1. Both inputs must therefore be 3D cubes with the same number of channels.
# Check this up front. Otherwise a shorter second cube only fails with an IndexError part-way through the merge,
# and a longer second cube has its extra channels silently dropped.
shape1 = None if cube1 is None else cube1.shape
shape2 = None if cube2 is None else cube2.shape
if shape1 is None or shape2 is None or len(shape1) != 3 or len(shape2) != 3:
    sys.exit('Error: both input files must hold a 3D cube (nchan, ny, nx) in the primary HDU, got shapes '
             + str(shape1) + ' and ' + str(shape2))
if shape1[0] != shape2[0]:
    sys.exit('Error: the input cubes have different numbers of channels (' + str(shape1[0]) + ' and '
             + str(shape2[0]) + '); channels are merged index-by-index, so these must match')


print('Generating rms maps...')
# Per-position noise estimate: the NaN-ignoring standard deviation of each spectrum along the spectral
# axis (axis 0). Every channel contributes, including any that contain signal.
rms1_map, rms2_map = (numpy.nanstd(data, axis=0) for data in (cube1, cube2))


# Optionally save the rms maps, may be useful for testing.
if make_rmsmaps:
    for rms_file, rms_map, cube_header in (('RMSMap1.fits', rms1_map, header1),
                                           ('RMSMap2.fits', rms2_map, header2)):
        # The rms maps are 2D images. Give them only the celestial (first two axes) part of the cube's WCS,
        # not the full 3D cube header, whose spectral-axis keywords would not describe a 2D image. The rms
        # is in the same units as the cube data, so carry BUNIT across when present.
        rms_header = WCS(cube_header, naxis=2).to_header()
        if 'BUNIT' in cube_header:
            rms_header['BUNIT'] = cube_header['BUNIT']
        fits.writeto(rms_file, rms_map, rms_header, overwrite=True)


# Take the celestial (first two axes) WCS of each cube, then pick an output grid that covers both
print('Generating merge cube header...')
wcs1_2d, wcs2_2d = WCS(header1, naxis=2), WCS(header2, naxis=2)
header1_2d, header2_2d = wcs1_2d.to_header(), wcs2_2d.to_header()

# Choosing the mosaic grid only needs each input's shape and celestial WCS, so a single representative
# channel map (the first) of each cube suffices.
array_list = [(cube1[0], header1_2d), (cube2[0], header2_2d)]
new_wcs, shape_out = find_optimal_celestial_wcs(array_list)

# 2D target header for the mosaic grid. shape_out is ordered (ny, nx), whereas NAXIS1 is the x axis.
ny_out, nx_out = shape_out
target_header_2d = new_wcs.to_header()
target_header_2d['NAXIS'] = 2
target_header_2d['NAXIS1'] = nx_out
target_header_2d['NAXIS2'] = ny_out

# 3D header for the merged cube: the 2D mosaic header plus a third axis with as many channels as cube 1
# (both cubes are assumed to have the same number of channels).
nchan = cube1.shape[0]
new_header = target_header_2d.copy()
new_header['NAXIS'], new_header['NAXIS3'] = 3, nchan


# Copy the spectral-axis keywords from cube1's header. Both cubes are assumed to share the same spectral
# calibration, since channels are merged index-by-index. Also copy:
#   - the rest frequency (the user specifies the keyword name, as it varies between packages);
#   - the restored-beam parameters (major axis, minor axis and position angle);
#   - the data units and the spectral reference frame.
# CUNIT3 matters: without it, readers assume the FITS default unit for the spectral axis. For example, a
# velocity axis defaults to m/s, so CRVAL3/CDELT3 given in km/s would be misread.
spectral_and_beam_keys = ['CTYPE3', 'CRVAL3', 'CDELT3', 'CRPIX3', 'CUNIT3', restfrq,
                          'BMAJ', 'BMIN', 'BPA', 'BUNIT', 'SPECSYS']
for key in spectral_and_beam_keys:
    if key in header1:
        new_header[key] = header1[key]

# The header inherited from the 2D WCS carries WCSAXES with the wrong value (2, when the cube has 3 axes).
# Remove it entirely, which ensures the WCS displays correctly in DS9. ignore_missing avoids a KeyError
# in case the keyword is not present.
new_header.remove('WCSAXES', ignore_missing=True)
		
# Resample both 2D rms maps onto the mosaic's celestial pixel grid, using the same 2D headers as the channel
# maps. Only the resampled values are kept; the footprints are not needed here.
rmap1 = reproject_interp((rms1_map, header1_2d), target_header_2d)[0]
rmap2 = reproject_interp((rms2_map, header2_2d), target_header_2d)[0]



print('Creating the merged cube...')
# Output cube on the mosaic grid; pixels that no input cube contributes to remain NaN.
merged_cube = numpy.full((nchan, shape_out[0], shape_out[1]), numpy.nan)

# Inverse-variance weight maps (1/rms^2). They depend only on position, not on channel, so compute them
# once here instead of in every iteration. Where rms is exactly zero the weight is inf (this case is kept
# as before). Where rms is NaN the weight is NaN, and such pixels are excluded by the validity masks below.
with numpy.errstate(divide='ignore', invalid='ignore'):
    inv_var1 = 1.0 / rmap1**2
    inv_var2 = 1.0 / rmap2**2

# Loop over spectral channels to reproject and combine each slice.
for i in range(nchan):
    sys.stdout.write('\r' + 'Merging channel ' + str(i + 1) + ' of ' + str(nchan) + '...')
    sys.stdout.flush()

    # Reproject the i-th channel map of each cube onto the mosaic grid
    array1, footprint1 = reproject_interp((cube1[i], header1_2d), target_header_2d)
    array2, footprint2 = reproject_interp((cube2[i], header2_2d), target_header_2d)

    # A pixel is valid data from a cube only where the reprojection covers it AND both the reprojected value
    # and the rms there are finite. Without the finiteness checks, a NaN in one cube (a blanked pixel, or an
    # all-NaN spectrum giving a NaN rms) can turn the weighted mean into NaN, or drop the pixel entirely,
    # discarding good data from the other cube at that position.
    valid1 = (footprint1 > 0) & numpy.isfinite(array1) & numpy.isfinite(rmap1)
    valid2 = (footprint2 > 0) & numpy.isfinite(array2) & numpy.isfinite(rmap2)

    # Weight = 1/rms^2 on valid pixels, zero elsewhere
    w1 = numpy.where(valid1, inv_var1, 0.0)
    w2 = numpy.where(valid2, inv_var2, 0.0)
    weight_sum = w1 + w2

    combined = numpy.full_like(array1, numpy.nan)

    # Where both cubes contribute, take the inverse-variance weighted mean
    both = (w1 > 0) & (w2 > 0) & (weight_sum > 0)
    combined[both] = (w1[both] * array1[both] + w2[both] * array2[both]) / weight_sum[both]

    # Where only one cube has valid data, use its value unchanged
    only1 = (w1 > 0) & (w2 == 0)
    only2 = (w2 > 0) & (w1 == 0)
    combined[only1] = array1[only1]
    combined[only2] = array2[only2]

    merged_cube[i] = combined

# Finish the in-place progress line so subsequent output starts on a new line
sys.stdout.write('\n')


# Write the mosaicked cube with its 3D header; any existing file of the same name is replaced.
fits.writeto(merge_cube_file, merged_cube, header=new_header, overwrite=True)

print('Done !')
