# All-in-one AGES cleaning script. Takes an input cube and can :
# - Apply Hanning smoothing
# - Fit and subtract a spectral baseline to all spectra
# - Apply a MINMED or MEDMED-style subtraction to each spatial bandpass
# - Output a S/N cube instead of the standard flux cube
# - Edit header keywords (e.g. add BUNIT = JY/BEAM)
# All steps are optional but the order is fixed. Although separate scripts are available for each
# operation, it is convenient to have everything possible in one step (especially since Hanning
# smoothing requires re-ordering large cubes, which is very tedious to do manually).
# Requires miriad to be installed (used for the Hanning smoothing step).
# For the S/N output, each spectrum is divided by its robust rms (1.4826 x median absolute deviation).

import math
import numpy
from astropy.io import fits as pyfits
import os
import shutil
import sys
import warnings
warnings.filterwarnings("ignore")

# *** USER PARAMETERS ***
# Name of the input FITS file. It must end in '.fits' and be in the same directory as this
# script, because the script changes to its own directory before reading any files.
infilename = 'WAVESFull.fits'

# Operations
hann = 0			# Hanning smoothing width in channels (0 to disable). Must be an odd number

# Fit and subtract a sigma-clipped polynomial from the spectral baseline (z-axis) of every spectrum
fitbaseline = False
clip = 2.0			# Sigma-clipping level when fitting the baseline
niter = 5			# Number of sigma-clipping iterations
order = 2			# Order of the polynomial to fit

# Estimate and subtract the level of each spatial bandpass (x-axis). No polynomial is fitted : the
# bandpass is split into nboxes boxes, the median of each box is measured, and the minimum (MINMED)
# or the median (MEDMED) of those box medians is subtracted from the whole bandpass.
spatial = None		# Options are 'MINMED', 'MEDMED', or None to disable
nboxes = 5			# Number of boxes for calculating individual medians

# Output a S/N cube instead of the standard flux cube. The baseline polynomial (same order, clip
# and niter as above) is always fitted in this case, and each baseline-subtracted spectrum is then
# divided by its robust rms.
SNCube = False

# List of FITS header keywords and values to add or alter (set headerkeys = None if nothing needs
# to be altered).
# Format : headerkeys = [[keyword, value], [keyword, value]...]. The "value" can be of any
# appropriate variable type.
headerkeys = None
# Some examples
#headerkeys = [['BUNIT', 'JY/BEAM']]
#headerkeys = [['BUNIT', 'JY/BEAM'], ['CDELT1', -0.0166666672738], ['CDELT2', 0.0166666672738], ['CDELT3', -5449.44175798001]]

# *** END OF USER PARAMETERS ***

# Base name for every file this script produces: the input name up to its first '.fits'
basename = infilename.partition('.fits')[0]
miriadname = '%s.mir' % (basename,)

# Work from the directory that contains this script, so all relative file names resolve there
abspath = os.path.abspath(__file__); dname = os.path.split(abspath)[0]
os.chdir(dname)


# 1) HANNING SMOOTHING (wrapper for miriad tasks)
def _run_miriad(command, output):
	"""Run a single miriad task command and check that it created *output*.

	The command is written to a temporary executable script (MiriadInputs.txt), run with all
	terminal output discarded, and the script is then deleted. Because the task's messages are
	discarded, the presence of *output* is the only success signal, so:

	Raises:
		RuntimeError: if *output* already exists before the task runs (presumably miriad will
			not overwrite it, so a stale result from an earlier run could be silently reused -
			confirm for your miriad version), or if it does not exist afterwards.
	"""
	if os.path.exists(output):
		raise RuntimeError(output+' already exists - remove it before running the Hanning smoothing')

	cmdfile = 'MiriadInputs.txt'
	with open(cmdfile, 'w') as ComFile:
		ComFile.write(command+'\n')

	# Make the command file executable, run it, and always remove it afterwards
	os.chmod(cmdfile, 0o755)
	try:
		os.system('./'+cmdfile+' >/dev/null 2>&1')
	finally:
		os.remove(cmdfile)

	if not os.path.exists(output):
		raise RuntimeError('miriad command failed (no '+output+' was produced) : '+command)


if hann > 0:
	# The smoothing width must be odd (see the user parameters); fail now rather than after
	# several slow conversion steps
	if hann % 2 == 0:
		raise ValueError('hann must be an odd number of channels (got '+str(hann)+')')

	print('BEGINNING HANNING SMOOTHING...')

	# Names of the temporary miriad datasets and of the smoothed FITS output
	reordername     = basename+'_reorder.mir'
	reorderhannname = basename+'_reorder_hann.mir'
	hannmirname     = basename+'_h'+str(hann)+'.mir'
	hannfitsname    = basename+'_h'+str(hann)+'.fits'

	# First, convert the FITS file to miriad format
	print('Converting to temporary miriad format...')
	_run_miriad('fits in='+str(infilename)+' op=xyin out='+str(miriadname), miriadname)

	# Next, reorder the axes (mode=321) - the header comment notes that Hanning smoothing
	# requires the cube to be re-ordered
	print('Reordering miriad file...')
	_run_miriad('reorder in='+str(miriadname)+' mode=321 out='+reordername, reordername)

	# Now we can actually do the Hanning smoothing !
	print('Applying Hanning smoothing, please be patient...')
	_run_miriad('hanning in='+reordername+' width='+str(hann)+' out='+reorderhannname, reorderhannname)

	# Re-reorder the file back to its original axis order (applying mode=321 a second time)
	print('Resetting axes ordering...')
	_run_miriad('reorder in='+reorderhannname+' mode=321 out='+hannmirname, hannmirname)

	# Convert it back to FITS format
	print('Converting back to FITS format...')
	_run_miriad('fits in='+hannmirname+' op=xyout out='+hannfitsname, hannfitsname)

	# Finally, clean-up phase. Delete all the miriad datasets produced (they are directories),
	# then switch the rest of the script over to the smoothed file
	print('Removing temporary files...')
	for tempname in (miriadname, reordername, reorderhannname, hannmirname):
		shutil.rmtree(tempname)

	infilename = hannfitsname
	basename   = basename+'_h'+str(hann)

	print('Hanning smoothing complete !')

	# Remember the smoothed file's name so it can be removed later: if other operations are
	# requested they write their own output file, and a separate file that is only Hanning
	# smoothed would then be redundant
	hannfilename = basename+'.fits'
	
	
# If any other operation is requested then we need to read in the FITS data
def _fit_spectral_baseline(spectra, channel, order, clip, niter):
	"""Fit and subtract a sigma-clipped polynomial baseline from one spectrum.

	Non-finite channels are ignored. On each of up to *niter* iterations a polynomial of degree
	*order* is fitted to the remaining channels, and channels whose residual lies more than
	*clip* standard deviations from the median residual are discarded. A final polynomial is
	then fitted to the ORIGINAL values of the surviving channels and subtracted everywhere.

	Args:
		spectra: 1-D array of flux values for one spectrum.
		channel: 1-D integer array of channel indices, same length as *spectra*.
		order, clip, niter: polynomial degree, clipping level and number of clipping passes.

	Returns:
		(fitted, usedchan): the baseline-subtracted spectrum over all channels, and the channel
		indices used for the final fit. (None, None) if the spectrum has no more than *order*
		finite channels, in which case it should be left untouched.
	"""
	good = numpy.isfinite(spectra)
	tempchan = channel[good]
	tempspec = spectra[good]
	if len(tempspec) <= order:
		return None, None

	for _ in range(niter):
		# Fit the current (residual) spectrum and keep what the polynomial does not explain
		temppoly = numpy.polyfit(tempchan, tempspec, order)
		tempspec = tempspec - numpy.polyval(temppoly, tempchan)

		rms = numpy.std(tempspec)
		av  = numpy.median(tempspec)
		deviants = numpy.flatnonzero(numpy.abs(tempspec - av) > clip*rms)

		# Stop when nothing is left to clip (further passes would refit an unchanged residual),
		# or when clipping would leave too few channels to constrain the polynomial
		if deviants.size == 0 or len(tempchan) - deviants.size <= order:
			break
		tempspec = numpy.delete(tempspec, deviants)
		tempchan = numpy.delete(tempchan, deviants)

	poly = numpy.polyfit(tempchan, spectra[tempchan], order)
	fitted = spectra - numpy.polyval(poly, channel)
	return fitted, tempchan


def _bandpass_level(bpass, nboxes, mode):
	"""Estimate the level of one spatial bandpass (one row of the cube along the x-axis).

	The bandpass is split into *nboxes* boxes of len(bpass)//nboxes pixels each (any remaining
	pixels at the end are not used for the estimate), and the median of the non-NaN values in
	each box is taken. The level is the minimum (mode 'MINMED') or the median (any other mode,
	i.e. 'MEDMED') of those box medians, ignoring boxes whose median is not finite (e.g. boxes
	that are entirely NaN). Returns NaN if no box gives a finite median.
	"""
	lb = len(bpass) // nboxes
	medlist = []
	for nb in range(nboxes):
		box = bpass[nb*lb:(nb+1)*lb]
		box = box[numpy.logical_not(numpy.isnan(box))]
		medlist.append(numpy.median(box) if box.size > 0 else numpy.nan)

	medlist = numpy.array(medlist)
	medlist = medlist[numpy.isfinite(medlist)]
	if medlist.size == 0:
		return numpy.nan
	if mode == 'MINMED':
		return numpy.min(medlist)
	return numpy.median(medlist)


if (fitbaseline == True) or (spatial == 'MINMED' or spatial == 'MEDMED') or (SNCube == True) or (headerkeys != '' and headerkeys is not None):
	print('Opening FITS file...')
	# Note that if we applied Hanning smoothing, the name of the file to use has already been
	# updated, see "infilename" above
	with pyfits.open(infilename) as FitsFile:
		image  = FitsFile[0].data
		header = FitsFile[0].header

	# numpy axis order is (z, y, x); the code below indexes the cube as image[z, y, x]
	sizez = image.shape[0]
	sizey = image.shape[1]
	sizex = image.shape[2]

	# 2) UPDATE SPECIFIED KEYWORDS
	if headerkeys is not None:
		for entry in headerkeys:
			header[entry[0]] = entry[1]

	# 3) POLYNOMIAL FITTING - has to be performed if set explicitly or a SN cube was requested
	if fitbaseline == True or SNCube == True:
		print('Applying polynomial fit...')
		channel = numpy.arange(sizez)

		for xp in range(sizex):
			sys.stdout.write('\r')
			sys.stdout.write('X-pixel '+str(xp)+' of '+str(sizex))
			sys.stdout.flush()
			for yp in range(sizey):
				spectra = numpy.array(image[:,yp,xp])
				fitted, usedchan = _fit_spectral_baseline(spectra, channel, order, clip, niter)
				if fitted is None:
					# Too few finite channels to fit a baseline : leave this spectrum untouched
					continue

				if SNCube == True:
					# Robust rms of the final masked residuals via the median absolute deviation
					# (see https://en.wikipedia.org/wiki/Median_absolute_deviation). A zero value
					# (e.g. a constant spectrum) turns this spectrum into inf/NaN.
					resid = fitted[usedchan]
					robrms = 1.4826 * numpy.median(numpy.fabs(resid - numpy.median(resid)))
					image[:,yp,xp] = fitted / robrms
				else:
					# Otherwise just use the fitted spectrum
					image[:,yp,xp] = fitted

	# 4) SPATIAL BANDPASS
	if spatial == 'MEDMED' or spatial == 'MINMED':
		print('\n'+'Removing spatial bandpass...')
		for yp in range(sizey):
			sys.stdout.write('\r')
			sys.stdout.write('Y-pixel '+str(yp)+' of '+str(sizey))
			sys.stdout.flush()
			for zp in range(sizez):
				bpass = numpy.array(image[zp,yp,:])
				bmed = _bandpass_level(bpass, nboxes, spatial)
				# With no usable level, leave the row as it is rather than turning it all into NaN
				if numpy.isfinite(bmed):
					image[zp,yp,:] = bpass - bmed

	# Output filename adjustments. Note that these are cumulative.

	# Set the filename according to whether we used S/N or not
	if SNCube == True:
		basename = basename+'_p'+str(order)+'_SN'
	elif fitbaseline == True:
		basename = basename+'_p'+str(order)

	# Extra name adjustment depending on spatial bandpass filter
	if spatial == 'MEDMED' or spatial == 'MINMED':
		basename = basename+'_'+spatial

	# Final name adjustment if header keywords updated
	if headerkeys is not None:
		basename = basename+'_UpdateHeader'

	# Finally we can write the feckin' thing and go home
	print('\n'+'Writing final output file...')
	pyfits.writeto(basename+'.fits', image, header=header)
	print('Done !')

	# If we applied Hanning smoothing, the smoothed-only intermediate file is now redundant,
	# since the smoothing is already included in the output we just wrote
	if hann > 0:
		os.remove(hannfilename)
