import pylab as p

import sys
import time 
import numpy as n
import pyfits as pf
import glob
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit
import cPickle

def _readFitsData(path, ext):
	"""Return the data of HDU number ``ext`` of the FITS file ``path``.

	The data are read before the file is closed, and the file is closed
	even when reading fails.
	"""
	hdul = pf.open(path)
	try:
		return hdul[ext].data
	finally:
		hdul.close()


def _loadResponse(path):
	"""Interpolator of column 6 versus column 0 of a filter-response table."""
	wl, resp = n.loadtxt(path, unpack=True, usecols=(0, 6))
	return interp1d(wl, resp)


# DEEP2 DR4 redshift catalogue: the object table lives in HDU 1.
summ = _readFitsData("/home/comparat/data/DEEP2/zcat.deep2.dr4.fits", 1)
# Every 1D spectrum on disk, processed in random order.
spl = glob.glob("/home/comparat/data/DEEP2/*/*/spec1d.*.fits")
n.random.shuffle(spl)

calibDir = "/home/comparat/data/DEEP2/calibFiles/"

# Chip-response polynomial coefficients (primary HDU of each file).
paramsEndr = _readFitsData(calibDir+"paramsendr.fits", 0)
params = _readFitsData(calibDir+"params.fits", 0)

# Two-column throughput table; its .y values are later indexed by pixel number.
thrLambda, thrValue = n.loadtxt(calibDir+"thr_go1200_80_og550.asc", unpack=True)
throughput = interp1d(thrLambda, thrValue)

# B, R and I filter responses (column 0: wavelength, column 6: response).
Bresponse = _loadResponse(calibDir+"Bresponse.txt")
Rresponse = _loadResponse(calibDir+"Rresponse.txt")
Iresponse = _loadResponse(calibDir+"Iresponse.txt")

def fun(x, a, b):
	"""Straight line ``a*x + b`` (slope ``a``, intercept ``b``); usable as a curve_fit model."""
	return a*x + b

class Spec1D:
	"""One DEEP2 1D spectrum (blue + red chip) together with its catalogue row.

	Wavelengths are assumed to be in Angstrom. Typical use::

		s = Spec1D(path, summ)
		s.correctQE()
		if not isinstance(s.fluxCal(), str):
			s.writeFCspec()
	"""
	# Wavelength window kept from the concatenated blue+red spectrum
	# (strict inequalities on both ends).
	lambdMin = 6000
	lambdMax = 10000

	def __init__(self,name,summ):
		"""Read the spectrum file ``name`` and select its rows in ``summ``.

		``name`` must end in ``spec1d.<mask>.<slit>.<objno>.fits``; those three
		integers are matched against the MASK, SLIT and OBJNO columns of
		``summ``. HDU 1 (blue) and HDU 2 (red) provide LAMBDA, SPEC and IVAR.
		"""
		self.name=name
		nameCut=self.name.split('/')[-1].split('.')
		self.mask=int(nameCut[1])
		self.slit=int(nameCut[2])
		self.ID=int(nameCut[3])
		sel=(summ['MASK']==self.mask)&(summ['SLIT']==self.slit)&(summ['OBJNO']==self.ID)
		self.data=summ[sel]
		hdS=pf.open(name)
		try:
			self.dB=hdS[1].data	# blue spectrum
			self.dR=hdS[2].data	# red spectrum
			# Header value minus one: used as a row index into params.T
			# (blue: chipNO, red: chipNO+4) in correctQE.
			self.chipNO=hdS[1].header['CHIPNO']-1
		finally:
			hdS.close()
		lb=n.hstack((self.dB['LAMBDA'][0],self.dR['LAMBDA'][0]))
		inWindow=(lb>self.lambdMin)&(lb<self.lambdMax)
		# Longest blue wavelength: everything above it comes from the red chip.
		self.lambdSwitch=n.max(self.dB['LAMBDA'][0])
		# Original pixel numbers of the kept pixels (used to sample throughput).
		self.pixSampled=n.arange(lb.size)[inWindow]
		self.lambd=lb[inWindow]
		self.spec=n.hstack((self.dB['SPEC'][0],self.dR['SPEC'][0]))[inWindow]
		self.ivar=n.hstack((self.dB['IVAR'][0],self.dR['IVAR'][0]))[inWindow]
		# 1-sigma error; pixels with ivar == 0 get an infinite error.
		self.specErr=self.ivar**(-0.5)

	def correctQE(self):
		"""Correct the counts for chip response, red-end QE and throughput.

		Uses the module-level ``params``, ``paramsEndr`` and ``throughput``.
		Sets ``xavg`` and ``cor2avg`` (red-end ramp variable and factor),
		``left``/``right`` (blue/red chip pixel masks) and the corrected
		``specTMP``, ``specErrTMP`` and ``ivarTMP``.
		"""
		meanLambd=n.mean(self.lambd)
		if self.lambd.max()-self.lambd.min() > 3000 or meanLambd<7300 or meanLambd>8300:
			# Warning only: the correction below is applied regardless.
			print("cannot QE correct")

		# Red-end ramp: xavg goes 0 -> 1 over [xravg, xravg+yravg] and blends
		# the linear paramsEndr correction with 1. Outside the ramp, or where
		# the blend is <= 1, the factor is 1 (no correction).
		xravg = 8900
		yravg = 150
		correctionavg = paramsEndr[0] + paramsEndr[1] * self.lambd
		self.xavg = (self.lambd - xravg)/yravg
		inRamp = (self.xavg > 0) & (self.xavg < 1)
		self.cor2avg = correctionavg*self.xavg + (1-self.xavg)
		self.cor2avg[~(inRamp & (self.cor2avg > 1))] = 1.

		self.left=(self.lambd<=self.lambdSwitch)
		self.right=(self.lambd>self.lambdSwitch)

		# Inverse quadratic chip response; coefficients from params.T.
		pB=params.T[self.chipNO]
		pR=params.T[self.chipNO+4]
		xB=self.lambd[self.left]
		xR=self.lambd[self.right]
		corr_b = 1./( pB[0] + pB[1]*xB + pB[2]*xB**2 )
		corr_r = 1./( pR[0] + pR[1]*xR + pR[2]*xR**2 )
		cor2R=self.cor2avg[self.right]

		self.specTMP=n.zeros_like(self.spec)
		self.specErrTMP=n.zeros_like(self.specErr)
		self.ivarTMP=n.zeros_like(self.ivar)

		self.specTMP[self.left]=self.spec[self.left]*corr_b
		self.specTMP[self.right]=self.spec[self.right]*corr_r*cor2R

		self.specErrTMP[self.left]=self.specErr[self.left]*corr_b
		self.specErrTMP[self.right]=self.specErr[self.right]*corr_r*cor2R

		# Inverse variance scales as 1/factor^2.
		self.ivarTMP[self.left]=self.ivar[self.left]/(corr_b*corr_b)
		self.ivarTMP[self.right]=self.ivar[self.right]/(corr_r*corr_r*cor2R*cor2R)

		# Divide out the throughput, sampled at the original pixel numbers.
		thr=throughput.y[self.pixSampled]
		self.specTMP=self.specTMP/thr
		self.specErrTMP=self.specErrTMP/thr
		self.ivarTMP=self.ivarTMP*thr**2

	def fluxCal(self):
		"""Scale the QE-corrected counts to flux using catalogue R and I mags.

		Must be called after correctQE. Sets fluxn/fluxnErr/ivar_fluxn (f_nu)
		and fluxl/fluxlErr (f_lambda).

		Returns the flux-per-count factor (an array), or the string "bad"
		when neither band yields a positive, finite flux-per-count.
		"""
		# Response-weighted mean counts in the R and I filters.
		respR=Rresponse(self.lambd)
		respI=Iresponse(self.lambd)
		countr = n.sum(self.specTMP*respR)/n.sum(respR)
		counti = n.sum(self.specTMP*respI)/n.sum(respI)
		# AB magnitude -> f_nu (in erg/s/cm^2/Hz)
		fluxr = 10**((self.data['MAGR'] + 48.6)/(-2.5))
		fluxi = 10**((self.data['MAGI'] + 48.6)/(-2.5))
		# One entry per matched catalogue row; only row 0 is used below.
		fpcr = fluxr / countr
		fpci = fluxi / counti
		effr = 6599.0889	# presumably the R effective wavelength
		effi = 8135.4026	# presumably the I effective wavelength
		yr = fpcr[0]
		yi = fpci[0]
		# A band is usable only if its flux-per-count is positive and finite.
		okR = yr > 0 and n.isfinite(yr)
		okI = yi > 0 and n.isfinite(yi)
		if okR and okI:
			# Power law through the two band points: the exact straight line
			# in log(flux-per-count) versus log(lambda).
			slope = (n.log(yi) - n.log(yr)) / (n.log(effi) - n.log(effr))
			intercept = n.log(yr) - slope*n.log(effr)
			fluxn_corr = n.exp( intercept + slope*n.log(self.lambd) )
		elif okR:
			fluxn_corr=fpcr
		elif okI:
			fluxn_corr=fpci
		else:
			return "bad"

		self.fluxn = fluxn_corr * self.specTMP
		self.fluxnErr = fluxn_corr * self.specErrTMP
		self.ivar_fluxn=self.ivarTMP/fluxn_corr**2

		# f_lambda = f_nu * c / lambda^2 with c in m/s and lambda in Angstrom;
		# the 1e-10 (m per Angstrom) gives erg/s/cm^2/Angstrom.
		self.fluxl=self.fluxn*299792458.0 / (self.lambd**2 * 10**(-10))
		self.fluxlErr=self.fluxnErr *299792458.0 / (self.lambd**2 * 10**(-10))
		return fluxn_corr

	def writeFCspec(self):
		"""Write lambd, fluxl, fluxlErr as three columns to <name minus '.fits'>_fc2.dat."""
		with open(self.name[:-5]+"_fc2.dat",'w') as ff:
			n.savetxt(ff,n.transpose([self.lambd,self.fluxl,self.fluxlErr]))




# Flux-calibrate every spectrum that has not been processed yet.
for specFile in spl:
	print("=> %s" % specFile)
	# An existing "_fc2.dat" next to the spectrum means it was already done.
	checkList=glob.glob(specFile[:-5]+"_fc2.dat")
	# NOTE(review): paths longer than 71 characters are skipped too; the
	# reason is not visible here (presumably a file-naming filter) -- confirm.
	if len(checkList)>=1 or len(specFile)>71:
		print("skip %s %s" % (specFile, checkList))
		continue

	spec1d=Spec1D(specFile,summ)
	# The ZBEST/ZQUALITY test (and fluxCal) need exactly one catalogue row;
	# zero or several rows would make the truth tests below ambiguous.
	if len(spec1d.data)!=1:
		print("no unique catalogue entry for %s" % specFile)
		continue
	if spec1d.data['ZBEST'][0]>=0 and spec1d.data['ZQUALITY'][0]>=1 and spec1d.lambd.size>0 and n.max(spec1d.lambd)>6500:
		print("%s %s %s %s %s" % (specFile, spec1d.slit, spec1d.mask, spec1d.ID, time.time()))
		spec1d.correctQE()
		print("=========================")
		fc=spec1d.fluxCal()
		# fluxCal returns the string "bad" on failure and an array otherwise;
		# check the type first so an array is never compared with a string.
		if isinstance(fc, str) and fc=="bad":
			print("not calibrable")
		else:
			spec1d.writeFCspec()

