from scipy.optimize import curve_fit
from scipy.stats import scoreatpercentile
import numpy as n
import pyfits as pf
import astropy
import scipy as sp
import glob
import pylab as p
import sys
# Rest-frame wavelengths of the emission lines of interest, in Angstrom.
# NOTE(review): the values appear to be vacuum wavelengths (air H-alpha is
# ~6562.8 A) -- confirm against the wavelength calibration of the spectra.
O2a = 3727.092  # [OII] 3726, in A
O2b = 3729.875  # [OII] 3729, in A
O2 = (O2a + O2b) / 2.  # mean wavelength of the [OII] doublet
# Balmer lines: H-gamma (n=5->2) lies redward of H-delta (n=6->2); the two
# values were previously assigned to the wrong names.
Hg = 4341.684  # H-gamma
Hd = 4102.892  # H-delta
Hb = 4862.683  # H-beta
O3a = 4960.295  # [OIII] 4959
O3b = 5008.240  # [OIII] 5007
Ha = 6564.61  # H-alpha
# Band edges (A) of the 4000 A break integrals in fit_4000D: blue band
# [intLim4k[0], intLim4k[1]] and red band [intLim4k[2], intLim4k[3]].
intLim4k = n.array([3750, 3950, 4050, 4150])
c = 299792458.0  # speed of light in m/s

# Survey field to process; it selects the catalogue file and the spectrum
# directory used throughout the script.
field = "UDEEP"

path_to_phot = "../catalogs/VVDS_spF02_{0}.fits".format(field)
path_to_summary = "../catalogs/intermediateCatalogs/VIPERS_{0}_SPECTRO_PDR1.fits".format(field)
specList = glob.glob("../spectra/F02_{0}/1D/*.fits".format(field))

# Load the object catalogue. Note that it is read from path_to_phot (the
# VVDS catalogue), not from path_to_summary; its first extension supplies
# the Z, ZFLAGS and NUM columns read by the main loop.
hdu = pf.open(path_to_phot)
summ = hdu[1].data

def _readFitsSpectrum(fileName):
    """Read one 1D spectrum FITS file, closing it before returning.

    Returns ``(wave, data)``: ``wave`` is the linear wavelength grid
    ``CRVAL1 + CDELT1 * arange(NAXIS1)`` built from the primary header, and
    ``data`` is a copy of the first row of the primary-HDU data array.
    """
    hdul = pf.open(fileName)
    try:
        header = hdul[0].header
        wave = header['CRVAL1'] + header['CDELT1'] * n.arange(header['naxis1'])
        # Copy the row: the data may be memory-mapped and has to stay valid
        # after the file is closed below.
        data = n.array(hdul[0].data[0])
    finally:
        hdul.close()
    return wave, data


def readSpec(id):
    """Load the cleaned spectrum of one object together with its noise spectrum.

    The files are located with the glob patterns
    ``../spectra/F02_<field>/1D/*<id>*clean.fits`` and
    ``../spectra/F02_<field>/1D/*<id>*noise.fits`` (``field`` is the
    module-level field name); the first element returned by ``glob.glob`` for
    each pattern is used. Both files are closed before returning.

    Returns ``(wave, fl, fl_N)``: the wavelength grid of the clean spectrum,
    the clean flux, and the noise spectrum.

    Raises IndexError if either pattern matches no file.

    NOTE(review): ``*<id>*`` also matches IDs that merely contain ``id`` as a
    substring, and the noise spectrum is assumed to share the clean
    spectrum's wavelength grid (this is not checked) -- confirm both.
    """
    # ``id`` shadows the builtin but is kept: it is the public parameter name.
    specDir = "../spectra/F02_" + field + "/1D/"
    # Resolve both files before opening either, so a missing file fails fast.
    specFileName = glob.glob(specDir + "*" + str(id) + "*clean.fits")[0]
    specFileNoise = glob.glob(specDir + "*" + str(id) + "*noise.fits")[0]
    wave, fl = _readFitsSpectrum(specFileName)
    fl_N = _readFitsSpectrum(specFileNoise)[1]
    return wave, fl, fl_N

def fit_4000D(wl, spec1d, err1d, z, limits=None):
    """Integrate a spectrum over the blue and red bands of the 4000 A break.

    Each flux array is linearly interpolated on ``wl`` and integrated with
    ``scipy.integrate.quad`` over both bands.

    Parameters
    ----------
    wl, spec1d, err1d : 1D arrays of wavelength, flux and flux error.
    z : redshift; currently unused.
    limits : sequence of four band edges ``[blue_lo, blue_hi, red_lo, red_hi]``
        in the units of ``wl``; defaults to the module-level ``intLim4k``.
        The edges are applied directly to ``wl`` with no redshift correction.
        NOTE(review): if ``wl`` is observed-frame, presumably the edges should
        be scaled by (1 + z) -- confirm.

    Returns
    -------
    tuple
        ``(left, leftU, leftL, right, rightU, rightL)``: the blue (left) and
        red (right) band integrals of ``spec1d``, ``spec1d + err1d`` (U) and
        ``spec1d - err1d`` (L).

    Raises ValueError (from ``interp1d``) if a band edge lies outside ``wl``.
    """
    # These were never imported at module level, so the function used to fail
    # with NameError; import them locally.
    from scipy.interpolate import interp1d
    from scipy.integrate import quad

    if limits is None:
        limits = intLim4k

    def _bandIntegrals(flux):
        # Integrals of the interpolated flux over the blue and red bands.
        toInt = interp1d(wl, flux)
        blue = quad(toInt, limits[0], limits[1])[0]
        red = quad(toInt, limits[2], limits[3])[0]
        return blue, red

    left, right = _bandIntegrals(spec1d)
    leftU, rightU = _bandIntegrals(spec1d + err1d)
    leftL, rightL = _bandIntegrals(spec1d - err1d)
    return left, leftU, leftL, right, rightU, rightL

def _fitGaussianOnContinuum(wl, spec1d, err1d, a0, domainLine, continu):
    """Fit a Gaussian of fixed centre ``a0`` on top of a fixed continuum level.

    Shared fitting core of ``fit_Line_leftC`` and ``fit_Line_rightC``. Only the
    pixels selected by the boolean array ``domainLine`` are fitted; the free
    parameters are the Gaussian width and its integrated flux, starting from
    ``[10., 1e-16]`` and weighted by ``err1d``.

    Returns
    -------
    tuple
        ``(a0, flux, fluxErr, sigma, sigmaErr, SN, continu, EW)`` on success,
        or ``(a0, -1, -1, -1, -1, -1, -1, -1)`` when the fit fails.
    """
    failed = (a0, -1, -1, -1, -1, -1, -1, -1)

    def flG(aa, sigma, F0):
        # constant continuum + normalised Gaussian of integrated flux F0
        return continu + F0 * (n.e ** (-(aa - a0) ** 2. / (2. * sigma ** 2.))) / (sigma * (2. * n.pi) ** 0.5)

    wlLine = wl[domainLine]
    specLine = spec1d[domainLine]
    errLine = err1d[domainLine]
    try:
        popt, pcov = curve_fit(flG, wlLine, specLine, p0=n.array([10., 1e-16]),
                               sigma=errLine, maxfev=500000000)
    except RuntimeError:
        # raised by curve_fit when the least-squares fit does not converge
        return failed
    # When the covariance cannot be estimated, older SciPy returns a scalar
    # inf and newer SciPy an inf-filled array: treat both as a failed fit.
    if not isinstance(pcov, n.ndarray) or not n.all(n.isfinite(pcov)):
        return failed

    sigma, flux = popt[0], popt[1]
    sigmaErr = pcov[0][0] ** 0.5
    fluxErr = pcov[1][1] ** 0.5

    # S/N of the line: A2 is the normal-equation matrix of the linear model
    # "amplitude * model1 + constant" weighted by 1/errLine**2; the S/N is the
    # inverse square root of the amplitude entry of its inverse.
    model1 = flG(wlLine, sigma, flux)
    A2 = n.array([[n.sum((model1 / errLine) ** 2.), n.sum(model1 / errLine ** 2.)],
                  [n.sum(model1 / errLine ** 2.), n.sum(1. / errLine ** 2.)]])
    if n.linalg.det(A2) != 0:
        SN = 1 / (n.linalg.inv(A2)[0][0]) ** 0.5
    else:
        # singular matrix: the S/N cannot be computed
        SN = -1

    EW = flux / continu
    return a0, flux, fluxErr, sigma, sigmaErr, SN, continu, EW


def fit_Line_leftC(wl, spec1d, err1d, z, a0, DLC=230):
    """Fit an emission line at ``a0`` with the continuum measured blueward of it.

    The line is fitted over ``a0 - 35 < wl < a0 + 35``; the continuum level is
    the median of ``spec1d`` over ``a0 - DLC < wl < a0 - 35``.

    Parameters
    ----------
    wl, spec1d, err1d : 1D arrays of wavelength, flux and flux error.
    z : redshift; not used by the fit (kept for interface compatibility).
    a0 : line centre, in the same units/frame as ``wl``.
    DLC : distance from ``a0`` to the far (blue) edge of the continuum window.

    Returns
    -------
    tuple
        ``(a0, flux, fluxErr, sigma, sigmaErr, SN, continu, EW)``, or
        ``(a0, -1, -1, -1, -1, -1, -1, -1)`` when ``wl`` is empty, ``a0`` is
        within ``DLC`` of the blue end of ``wl``, either window has 2 or fewer
        pixels, or the fit fails.
    """
    domainLine = (wl > a0 - 35) & (wl < a0 + 35)
    domainCont = (wl > a0 - DLC) & (wl < a0 - 35)
    if (len(wl) > 0 and a0 > wl.min() + DLC
            and len(domainLine.nonzero()[0]) > 2
            and len(domainCont.nonzero()[0]) > 2):
        continu = n.median(spec1d[domainCont])
        return _fitGaussianOnContinuum(wl, spec1d, err1d, a0, domainLine, continu)
    return a0, -1, -1, -1, -1, -1, -1, -1


def fit_Line_rightC(wl, spec1d, err1d, z, a0, DLC=230):
    """Fit an emission line at ``a0`` with the continuum measured redward of it.

    The line is fitted over ``a0 - 35 < wl < a0 + 35``; the continuum level is
    the median of ``spec1d`` over ``a0 + 35 < wl < a0 + DLC``.

    Parameters
    ----------
    wl, spec1d, err1d : 1D arrays of wavelength, flux and flux error.
    z : redshift; not used by the fit (kept for interface compatibility).
    a0 : line centre, in the same units/frame as ``wl``.
    DLC : distance from ``a0`` to the far (red) edge of the continuum window.

    Returns
    -------
    tuple
        ``(a0, flux, fluxErr, sigma, sigmaErr, SN, continu, EW)``, or
        ``(a0, -1, -1, -1, -1, -1, -1, -1)`` when ``wl`` is empty, ``a0`` is
        within ``DLC`` of the red end of ``wl``, either window has 2 or fewer
        pixels, or the fit fails.
    """
    domainLine = (wl > a0 - 35) & (wl < a0 + 35)
    domainCont = (wl > a0 + 35) & (wl < a0 + DLC)
    if (len(wl) > 0 and a0 < wl.max() - DLC
            and len(domainLine.nonzero()[0]) > 2
            and len(domainCont.nonzero()[0]) > 2):
        continu = n.median(spec1d[domainCont])
        return _fitGaussianOnContinuum(wl, spec1d, err1d, a0, domainLine, continu)
    return a0, -1, -1, -1, -1, -1, -1, -1


# Lines measured for every object, in output-column order:
# (fitting function, rest-frame wavelength, extra keyword arguments).
lineFits = [
    (fit_Line_leftC, O2, {}),
    (fit_Line_leftC, Hb, {}),
    (fit_Line_leftC, O3a, {'DLC': 130}),
    (fit_Line_rightC, O3b, {}),
    (fit_Line_leftC, Ha, {'DLC': 130}),
]

output = []
for NN, redshift, zflag in zip(summ['NUM'], summ['Z'], summ['ZFLAGS']):
    # only objects with 2 <= ZFLAGS <= 9 and a positive redshift are fitted
    if not (2. <= zflag <= 9 and redshift > 0.):
        continue
    wl, fl, flErr = readSpec(NN)
    mask = (fl > flErr)
    # NOTE(review): the fits use the pixels where mask == 0, i.e. where
    # fl <= flErr; pixels whose flux exceeds the noise -- including any
    # significant line emission -- are dropped. Confirm this is not inverted.
    good = (mask == 0)
    if len(wl[good]) <= 10:
        continue
    sys.stdout.write("%s %s %s\n" % (redshift, zflag, NN))
    wlGood, flGood, errGood = wl[good], fl[good], flErr[good]
    # each line is fitted at its observed-frame centre rest * (1 + z)
    row = [NN]
    for fitter, rest, extra in lineFits:
        row.append(fitter(wlGood, flGood, errGood, redshift, rest * (1 + redshift), **extra))
    output.append(n.hstack(row))


# Column names: the object id followed by 8 values for each fitted line.
head=" id a0_Oii flux_Oii fluxErr_Oii sigma_Oii sigmaErr_Oii SN_Oii continu_Oii EW_Oii a0_Hb flux_Hb fluxErr_Hb sigma_Hb sigmaErr_Hb SN_Hb continu_Hb EW_Hb a0_O3a flux_O3a fluxErr_O3a sigma_O3a sigmaErr_O3a SN_O3a continu_O3a EW_O3a a0_O3b flux_O3b fluxErr_O3b sigma_O3b sigmaErr_O3b SN_O3b continu_O3b EW_O3b a0_Ha flux_Ha fluxErr_Ha sigma_Ha sigmaErr_Ha SN_Ha continu_Ha EW_Ha "

n.savetxt("vvds-fluxes-{0}.dat".format(field), output, header=head)

# Output layout after the leading id column, 8 values per line:
# a0_Oii flux_Oii fluxErr_Oii sigma_Oii sigmaErr_Oii SN_Oii continu_Oii EW_Oii
# a0_Hb flux_Hb fluxErr_Hb sigma_Hb sigmaErr_Hb SN_Hb continu_Hb EW_Hb
# a0_O3a flux_O3a fluxErr_O3a sigma_O3a sigmaErr_O3a SN_O3a continu_O3a EW_O3a
# a0_O3b flux_O3b fluxErr_O3b sigma_O3b sigmaErr_O3b SN_O3b continu_O3b EW_O3b
# a0_Ha flux_Ha fluxErr_Ha sigma_Ha sigmaErr_Ha SN_Ha continu_Ha EW_Ha

# Stop here: nothing after this call is executed.
sys.exit()

# NOTE(review): everything below is unreachable -- the sys.exit() call above
# ends the script. It is leftover single-object plotting code that calls names
# not defined anywhere in the visible source (readSpec2, phot, fit_O3_Line,
# fit_O2_Line) and reads catalogue columns ('id_IAU', 'zspec') that the main
# loop does not use, so it would fail if re-enabled as-is.
# Object to inspect.
id2=124066271
# Numeric IDs parsed from the 'id_IAU' strings (characters from index 7 on).
iid=n.array([ int(el[7:]) for el in summ['id_IAU'] ])
redshift=summ['zspec'][(iid==id2)]
wl,fl,flErr=readSpec2(id2,phot,summ)

# NOTE(review): `mask` is not recomputed for this object; it would be
# whatever value the main loop left behind.
a0O3,fluxO3,fluxErrO3,sigmaO3,sigmaErrO3,SNO3,continuO3,EWO3,wlO3,modelO3=fit_O3_Line(wl[(mask==0)],fl[(mask==0)],flErr[(mask==0)],redshift)
a0O2,fluxO2,fluxErrO2,sigmaO2,sigmaErrO2,SNO2,continuO2,EWO2,wlO2,modelO2=fit_O2_Line(wl[(mask==0)],fl[(mask==0)],flErr[(mask==0)],redshift)

# Plot the selected flux (blue) and its error (red) on a logarithmic y axis.
p.plot(wl[(mask==0)],fl[(mask==0)],'b')
p.plot(wl[(mask==0)],flErr[(mask==0)],'r')
#p.plot(wlO3,modelO3,'g')
#p.plot(wlO2,modelO2,'g')
p.yscale('log')

p.show()


