# Script to convert a series of FITS data files into a sequence of PNG images.
# Specially designed for the massive HI4PI data set, which is difficult to process
# as a single FITS file.
# Adapted from a FRELLED script and stripped to the bare minimum of functionality.

import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
try:
	import pyfits
except:
	from astropy.io import fits as pyfits
import glob
import math
import os
import sys
import multiprocessing
import time
import traceback
from sys import platform

def clearscreen():
	"""Clear the terminal with the platform's shell command ('cls' on Windows, 'clear' elsewhere)."""
	command = 'cls' if platform == 'win32' else 'clear'
	os.system(command)


# Change to the directory containing this script, so that the relative FITS
# filenames and the ./Images output folder are resolved from there.
abspath = os.path.abspath(__file__)
dname = os.path.split(abspath)[0]
os.chdir(dname)

# Start with a clean terminal.
clearscreen()



# Display range: data values mapped to the bottom and top of the colour map
# (passed to imshow via minv/maxv).
Smin, Smax = 0.1, 150.0
# Numeric index of the colour map to use (see the list of options below).
colours = 1
# True for logarithmic colour scaling, False for linear.
logs = True
# Channels to render: minchan is the first channel; maxchan is used as the
# exclusive upper bound of range() when the work is distributed.
minchan, maxchan = 200, 750

# Input files are named basename + 3-digit channel number + '.fits'.
basename = 'Channel_'


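# Available colour maps (set `colours` above to one of these numbers).
# Options 1, 2 and 11-16 are built into matplotlib; 3-10 are custom maps
# defined below.
# 1 = grey
# 2 = hot
# 3 = cold
# 4 = fire
# 5 = red
# 6 = yellow
# 7 = green
# 8 = cyan
# 9 = purple
# 10 = high contrast
# 11 = jet
# 12 = RdBu
# 13 = cool
# 14 = RdYlBu
# 15 = RdYlGn
# 16 = spectral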


# Define custom colour maps. Each dict gives, per colour channel, a sequence
# of (x, value_below_x, value_above_x) anchor points as expected by
# matplotlib.colors.LinearSegmentedColormap. Every map is also registered
# with matplotlib under its name ('cold', 'fire', ...).


def _register_cmap(cmap):
	"""Register *cmap* with matplotlib on both old and new versions.

	``pyplot.register_cmap`` was deprecated in matplotlib 3.7 and removed in
	3.9; its replacement is ``matplotlib.colormaps.register``. Older
	installations (including the last Python 2 releases) only provide the
	pyplot function, so fall back to it when the registry method is missing.
	"""
	registry = getattr(matplotlib, 'colormaps', None)
	if registry is not None and hasattr(registry, 'register'):
		registry.register(cmap)
	else:
		plt.register_cmap(cmap=cmap)


# COLD
cdict = {'red':   ((0.0, 0.0, 0.0),
                   (0.75, 0.0, 0.0),
                   (1.0, 1.0, 1.0)),
         'green': ((0.0, 0.0, 0.0),
                   (0.25, 0.0, 0.0),
                   (0.75, 1.0, 1.0),
                   (1.00, 1.0, 1.0)),
         'blue':  ((0.0, 0.0, 0.0),
                   (0.25, 1.0, 1.0),
                   (1.0, 1.0, 1.0))}

cmap1 = matplotlib.colors.LinearSegmentedColormap('cold', cdict, N=256)
_register_cmap(cmap1)

# FIRE
cdict2 = {'red':   ((0.0, 0.0, 0.0),
                    (0.25, 1.0, 1.0),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.25, 0.0, 0.0),
                    (0.75, 1.0, 1.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.75, 0.0, 0.0),
                    (1.0, 1.0, 1.0))}

cmap2 = matplotlib.colors.LinearSegmentedColormap('fire', cdict2, N=256)
_register_cmap(cmap2)

# RED
cdict3 = {'red':   ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.0, 1.0, 1.0))}

cmap3 = matplotlib.colors.LinearSegmentedColormap('red', cdict3, N=256)
_register_cmap(cmap3)

# YELLOW
cdict4 = {'red':   ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.0, 1.0, 1.0))}

cmap4 = matplotlib.colors.LinearSegmentedColormap('yellow', cdict4, N=256)
_register_cmap(cmap4)

# GREEN
cdict5 = {'red':   ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.0, 1.0, 1.0))}

cmap5 = matplotlib.colors.LinearSegmentedColormap('green', cdict5, N=256)
_register_cmap(cmap5)

# CYAN
cdict6 = {'red':   ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.0, 1.0, 1.0))}

cmap6 = matplotlib.colors.LinearSegmentedColormap('cyan', cdict6, N=256)
_register_cmap(cmap6)

# PURPLE
cdict7 = {'red':   ((0.0, 0.0, 0.0),
                    (0.5, 0.5, 0.5),
                    (1.0, 1.0, 1.0)),
          'green': ((0.0, 0.0, 0.0),
                    (0.5, 0.0, 0.0),
                    (1.00, 1.0, 1.0)),
          'blue':  ((0.0, 0.0, 0.0),
                    (0.5, 1.0, 1.0),
                    (1.0, 1.0, 1.0))}

cmap7 = matplotlib.colors.LinearSegmentedColormap('purple', cdict7, N=256)
_register_cmap(cmap7)

# HIGH CONTRAST
# A grey ramp with brightness level**x at each tenth of the range, so the
# lower part of the range is strongly suppressed. All three channels use
# the same anchor points.
x = 5.0
_hc_levels = (0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
_hc_points = tuple((v, pow(v, x), pow(v, x)) for v in _hc_levels)
cdict8 = {'red': _hc_points,
          'green': _hc_points,
          'blue': _hc_points}

cmap8 = matplotlib.colors.LinearSegmentedColormap('highcont', cdict8, N=256)
_register_cmap(cmap8)


# Translate the numeric `colours` setting into something imshow accepts:
# the name of a built-in matplotlib colour map, or one of the custom
# Colormap objects (cmap1 ... cmap8) defined above.
_colour_choices = {
	1: 'gray',
	2: 'hot',
	3: cmap1,    # cold
	4: cmap2,    # fire
	5: cmap3,    # red
	6: cmap4,    # yellow
	7: cmap5,    # green
	8: cmap6,    # cyan
	9: cmap7,    # purple
	10: cmap8,   # high contrast
	11: 'jet',
	12: 'RdBu',
	13: 'cool',
	14: 'RdYlBu',
	15: 'RdYlGn',
	# 'spectral' was removed in matplotlib 2.2; 'nipy_spectral' is the same map
	# under its current name.
	16: 'nipy_spectral',
}
if isinstance(colours, int):
	# Fail here, with a clear message, instead of deep inside every worker
	# process when matplotlib is handed a bare integer.
	if colours not in _colour_choices:
		raise ValueError('colours must be an integer from 1 to 16, got %r' % (colours,))
	colours = _colour_choices[colours]
# Any non-integer value (e.g. a colour-map name string or a Colormap object)
# is left unchanged for matplotlib to interpret.
	
	
		
# Disable interactive plotting in matplotlib: figures are only saved to disk.
plt.ioff()


# Get base size from channel 000. Only this one file is inspected; the other
# channels are presumably the same shape (the PNG size in plotxy is derived
# from nx/ny) -- TODO confirm for partial data sets.
# Opened read-only: nothing is written back, and 'update' mode would allow
# pyfits to modify the (large) input file on close.
basefile = pyfits.open(basename + '000.fits', mode='readonly')
try:
	baseimage = basefile[0].data
	nz = baseimage.shape[0]
	ny = baseimage.shape[1]
	nx = baseimage.shape[2]
finally:
	# Always release the file handle, even if the data has fewer axes than
	# expected.
	basefile.close()

# Display range passed to imshow (through its normaliser) in plotxy.
minv = Smin
maxv = Smax


	
def plotxy(i):
	"""Render FITS channel *i* to a PNG image in the ``Images`` directory.

	Reads ``basename + <i zero-padded to 3 digits> + '.fits'``, takes the
	first plane of its data, and saves it as
	``<dname>/Images/Channel_<i+1 zero-padded to 3 digits>.png`` with one
	output pixel per data pixel and no axes or margins.

	Relies on module-level settings: ``colours`` (colour map), ``logs``
	(log vs. linear scaling), ``minv``/``maxv`` (display range), ``nx``/``ny``
	(output size in pixels), ``basename``, ``maxchan`` and ``dname``.

	:param i: channel number of the input file.
	"""
	clearscreen()
	print('Rendering channels...' + str(i).zfill(3) + '/' + str(maxchan).zfill(3))

	# Read-only: the input files are never modified here.
	fitsfile = pyfits.open(basename + str(i).zfill(3) + '.fits', mode='readonly')
	try:
		image = fitsfile[0].data
		# Converted original cube with miriad, which always outputs two
		# channels even though we only want one: keep the first plane.
		imslice = image[0, :, :]

		# vmin/vmax go into the normaliser itself: newer matplotlib rejects
		# passing norm= together with vmin=/vmax= to imshow.
		if logs:
			norm = matplotlib.colors.LogNorm(vmin=minv, vmax=maxv)
		else:
			norm = matplotlib.colors.Normalize(vmin=minv, vmax=maxv)

		fig = plt.figure()
		try:
			# This works since dpi=100, so images will have correct size and shape
			fig.set_size_inches(float(nx) / 100.0, float(ny) / 100.0)
			# Trick to remove any axes and whitespace from the images
			ax = plt.Axes(fig, [0., 0., 1., 1.])
			ax.set_axis_off()
			fig.add_axes(ax)

			try:
				ax.imshow(imslice, aspect='auto', interpolation='none', cmap=colours, norm=norm)
			except ValueError:
				# Old matplotlib versions reject interpolation='none'.
				ax.imshow(imslice, aspect='auto', interpolation='nearest', cmap=colours, norm=norm)
			# Flip vertically so the first data row ends up at the bottom of the PNG.
			ax.invert_yaxis()

			# Zero-pad the output number to 3 digits (ensures compatibility with kvis)
			n = str(i + 1).zfill(3)
			filename = os.path.join(str(dname), 'Images', 'Channel_' + n + '.png')

			fig.savefig(filename, dpi=100, facecolor='black')
		finally:
			# Free the figure even if rendering or saving failed.
			plt.close(fig)
	finally:
		fitsfile.close()
		

		


		
if __name__ == '__main__':
	# Create the output directory before any worker tries to write to it.
	# plotxy saves into <dname>/Images, so create exactly that path.
	outdir = os.path.join(dname, 'Images')
	if not os.path.isdir(outdir):
		os.makedirs(outdir)

	# Worker processes (one per CPU by default); each call renders one channel.
	pool = multiprocessing.Pool()
	try:
		# Renders channels minchan .. maxchan-1.
		pool.map(plotxy, range(minchan, maxchan))
	finally:
		# Shut the workers down cleanly once all channels are done (or on error).
		pool.close()
		pool.join()
	print('Done !')
