import numpy
import pyfits
from pyfits import open as pyopen
import os
import glob
import math
from math import *
import scipy
from scipy import optimize as opt

# Program to fit and remove Gaussians from a data cube. The user specifies the
# regions to examine in a settings file (designed for use with FRELLED).
#
# Settings file (DeGaussSettings.txt, in the same directory as this script):
#   line 1      : name of the input FITS cube
#   later lines : midx midy minx maxx miny maxy minz maxz   (integers, one region per line)
#
# In each channel, a Gaussian is generated whose peak equals the peak value in
# the box for that channel. Its centre is varied over a grid of positions within
# about 2 pixels of (midx, midy). Residuals are calculated for each centre, and
# the one with the lowest standard deviation determines the final centre.
#
# This version also tries to estimate the true peak flux. Once the centre is
# found, the peak value is varied over about +/- 2 x the best residual rms about
# the measured peak, and the value that minimises the standard deviation of the
# residual is chosen.
#
# Outputs a new FITS file, '<input name>_degaussed.fits'.


# Change to the directory where the script is located, so that the settings
# file and the FITS file named in it are resolved relative to the script.
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)

# Read the whole settings file and close it straight away. The first line is
# the FITS file name; the remaining lines (one region each) stay in `settings`
# for the main loop to iterate over.
with open('DeGaussSettings.txt', 'r') as settings_file:
	settings = settings_file.readlines()

if not settings or not settings[0].strip():
	raise ValueError('DeGaussSettings.txt: first line must be the name of the FITS file')

# Drop the line break (and any stray surrounding whitespace) from the file name
FitsName = settings.pop(0).strip()
guess = FitsName.split('.fits')[0]

# Full double-precision pi, kept as a module-level name for any code expecting it
pi = math.pi

FWHM = 3.5		# Beam FWHM in pixels, found empirically by subjective judgement (2.7 is also noted as a candidate).

# FWHM = 2*sqrt(2*ln 2) * sigma, i.e. sigma = FWHM / 2.354820045
sigma = FWHM / (2.0*math.sqrt(2.0*math.log(2.0)))

outfile = str(guess)+'_degaussed.fits'

# FitsFile is left open here: `image` may still be backed by the file until the
# output has been written.
FitsFile = pyopen(FitsName)

image = FitsFile[0].data
header = FitsFile[0].header

if image is None:
	raise ValueError('%s: primary HDU contains no data' % FitsName)
if image.ndim != 3:
	raise ValueError('%s: expected a 3-D cube (channel, y, x), got shape %s' % (FitsName, image.shape))

# Model cube (sum of fitted Gaussians). A floating-point dtype is used so the
# Gaussians are not truncated if the cube is stored as integers; float32 and
# float64 cubes keep their own precision.
gaussimage = numpy.zeros(image.shape, dtype=numpy.result_type(image.dtype, numpy.float32))
# Output cube: until anything is subtracted it is just the input (not zeros).
degauimage = image - gaussimage


# Fitting helpers. All coordinates are global pixel coordinates: xgrid/ygrid
# hold the x/y position of every pixel in the box (built with numpy.mgrid below).

def _gaussian_model(peak, xc, yc, xgrid, ygrid, sigma):
	"""Return a circular 2-D Gaussian of height `peak` centred on (xc, yc)."""
	r2 = (xgrid - xc)**2 + (ygrid - yc)**2
	return peak * numpy.exp(-r2 / (2.0 * sigma * sigma))


def _best_centre(imslice, peak, midx, midy, xgrid, ygrid, sigma):
	"""Grid-search the Gaussian centre over (midx, midy) +/- 2 px in 0.25 px steps.

	Returns (rms, xc, yc) for the centre whose residual (imslice - model) has
	the smallest standard deviation. NaN pixels are ignored so that a single
	blanked pixel does not make every candidate's rms NaN.
	"""
	best = None
	for xc in numpy.linspace(midx - 2.0, midx + 2.0, 17):
		for yc in numpy.linspace(midy - 2.0, midy + 2.0, 17):
			model = _gaussian_model(peak, xc, yc, xgrid, ygrid, sigma)
			rms = numpy.nanstd(imslice - model)
			if best is None or rms < best[0]:
				best = (rms, xc, yc)
	return best


def _best_peak_offset(imslice, peak, rms, xc, yc, xgrid, ygrid, sigma):
	"""Return the offset in [-2*rms, +2*rms] (steps of 0.4*rms) that, added to
	`peak`, minimises the NaN-ignoring standard deviation of the residual.

	Returns 0.0 when rms is zero or not finite: there is then no noise range
	to search (a zero step would make the range undefined).
	"""
	if not (numpy.isfinite(rms) and rms > 0):
		return 0.0
	offsets = numpy.linspace(-2.0 * rms, 2.0 * rms, 11)
	devs = [numpy.nanstd(imslice - _gaussian_model(peak + off, xc, yc, xgrid, ygrid, sigma))
		for off in offsets]
	return offsets[int(numpy.argmin(devs))]


for line in settings:
	fields = line.split()
	if not fields:
		# Ignore blank lines, e.g. a trailing empty line in the settings file
		continue
	midx, midy, minx, maxx, miny, maxy, minz, maxz = [int(v) for v in fields]

	# Clip the (exclusive) upper bounds to the cube so an over-large box
	# cannot index past its edges.
	maxz = min(maxz, image.shape[0])
	maxy = min(maxy, image.shape[1])
	maxx = min(maxx, image.shape[2])
	if maxx <= minx or maxy <= miny:
		print('Skipping region %s: box lies outside the cube' % line.strip())
		continue

	# Global pixel coordinates of every pixel in the box (same for all channels)
	ygrid, xgrid = numpy.mgrid[miny:maxy, minx:maxx]

	for zp in range(minz, maxz):

		print(zp)

		imslice = image[zp, miny:maxy, minx:maxx]

		# Peak value in the box, ignoring NaNs. A box with no finite data cannot
		# be fitted; skip it rather than adding NaNs to the model cube.
		chanmax = numpy.nanmax(imslice)
		if not numpy.isfinite(chanmax):
			print('Skipping channel %s: no finite data in box' % zp)
			continue

		# Best centre for a Gaussian of height chanmax. Its residual rms is then
		# used as the noise estimate for the peak-flux search below.
		rms, xc, yc = _best_centre(imslice, chanmax, midx, midy, xgrid, ygrid, sigma)

		print('Using pixel  %s %s  at channel  %s' % (xc, yc, zp))

		# Allow the measured peak to have been shifted by noise of up to +/- 2 rms
		noise = _best_peak_offset(imslice, chanmax, rms, xc, yc, xgrid, ygrid, sigma)

		# Add the final Gaussian to the model cube. Plain assignment (rather than
		# +=) lets numpy cast the float model into gaussimage's dtype, whatever it is.
		box = gaussimage[zp, miny:maxy, minx:maxx]
		gaussimage[zp, miny:maxy, minx:maxx] = box + _gaussian_model(chanmax + noise, xc, yc, xgrid, ygrid, sigma)

# Subtract the accumulated model once, after every region has been processed.
# With no regions in the settings file this simply reproduces the input cube.
degauimage = image - gaussimage

# Write the de-Gaussed cube with the input header. clobber=True (a real boolean
# rather than the string 'true') overwrites any existing output file.
pyfits.writeto(outfile, degauimage, header=header, clobber=True)

# The input cube is no longer needed once the output has been written.
FitsFile.close()

print('done !')

