"""
Continuous Wavelet Transform.
"""

## This code is written by Davide Albanese, <albanese@fbk.eu> and
## Marco Chierici, <chierici@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

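## See: C. Torrence and G. P. Compo, "A Practical Guide to Wavelet Analysis",
## Bull. Amer. Meteor. Soc., 79(1), 61-78, 1998. Page and equation numbers
## cited in this file refer to that paper.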

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

from numpy import *
import cwb as waveletb # wb for pure python functions
import _extend

__all__ = ["cwt", "icwt", "angularfreq", "scales", "compute_s0"]


def angularfreq(N, dt):
    """Compute the angular frequencies of an N-point discrete Fourier grid.

    Input

      * *N*  - [integer] number of data samples
      * *dt* - [float] time step

    Output

      * *angular frequencies* - [1D numpy array float], one value per
        sample, in the same ordering as the bins of fft.fft

    """

    # Eq. (5) at page 64: bins 0..N/2 (inclusive) carry the non-negative
    # frequencies 2*pi*k/(N*dt); the remaining bins wrap around to the
    # negative frequencies 2*pi*(k - N)/(N*dt).
    k = arange(N)
    k = where(k <= N / 2.0, k, k - N)
    return (2 * pi * k) / (N * dt)


def scales(N, dj, dt, s0):
    """Compute scales.

    Input

      * *N*  - [integer] number of data samples
      * *dj* - [float] scale resolution
      * *dt* - [float] time step
      * *s0* - [float] smallest scale

    Output

      * *scales* - [1D numpy array float], s0 * 2**(j * dj) for
        j = 0, 1, ..., J with J = floor(log2(N * dt / s0) / dj)

    Raises

      * *ValueError* - if dj, s0 or N * dt is not positive, or if s0 is
        larger than N * dt (no scale fits the series)
    """

    #  See (9) and (10) at page 67.

    span = N * dt
    if not (span > 0 and s0 > 0 and dj > 0):
        raise ValueError("dj, s0 and N * dt must be positive")

    # floor() returns a float; numpy requires an integer array length.
    J = int(floor(dj**-1 * log2(span / s0)))
    if J < 0:
        raise ValueError("s0 must not be larger than N * dt")

    return s0 * 2.0 ** (arange(J + 1) * dj)


def compute_s0(dt, p, wf):
    """Compute the smallest scale s0 for a given wavelet.

    Each expression equals 2 * dt divided by the wavelet's ratio between
    equivalent Fourier period and scale, i.e. the smallest scale maps to a
    Fourier period of two time steps.

    Input

      * *dt* - [float] time step
      * *p*  - [float] omega0 ('morlet') or order ('paul', 'dog')
      * *wf* - [string] wavelet function ('morlet', 'paul', 'dog')

    Output

      * *s0* - [float]

    Raises

      * *ValueError* - if *wf* is not one of the supported wavelets
    """

    if wf not in ("dog", "paul", "morlet"):
        raise ValueError("wavelet '%s' is not available" % wf)

    if wf == "morlet":
        return (dt * (p + sqrt(2 + p**2))) / (2 * pi)

    if wf == "paul":
        return (dt * ((2 * p) + 1)) / (2 * pi)

    # Remaining case: "dog".
    return (dt * sqrt(p + 0.5)) / pi


def cwt(x, dt, dj, wf="dog", p=2, extmethod='none', extlength='powerof2'):
    """Continuous Wavelet Transform.

    The mean of the data is removed, the series is optionally extended,
    and the transform is computed in Fourier space for every scale given
    by scales(). Samples belonging to the extension are dropped from the
    result.

    :Parameters:
      x : 1d array_like float
        data
      dt : float
         time step
      dj : float
         scale resolution (smaller values of dj give finer resolution)
      wf : string ('morlet', 'paul', 'dog')
         wavelet function
      p : float
        wavelet function parameter
      extmethod : string ('none', 'reflection', 'periodic', 'zeros')
                indicates which extension method to use
      extlength : string ('powerof2', 'double')
                indicates how to determine the length of the extended data

    :Returns:
      (X, scales) : (2d ndarray complex, 1d ndarray float)
                  transformed data (one row per scale, len(x) columns),
                  scales

    :Raises:
      ValueError : if wf is not available

    Example:

    >>> import numpy as np
    >>> import mlpy
    >>> x = np.array([1,2,3,4,3,2,1,0])
    >>> mlpy.cwt(x=x, dt=1, dj=2, wf='dog', p=2)
    (array([[ -4.66713159e-02 -6.66133815e-16j,
             -3.05311332e-16 +2.77555756e-16j,
              4.66713159e-02 +1.38777878e-16j,
              6.94959463e-01 -8.60422844e-16j,
              4.66713159e-02 +6.66133815e-16j,
              3.05311332e-16 -2.77555756e-16j,
             -4.66713159e-02 -1.38777878e-16j,
             -6.94959463e-01 +8.60422844e-16j],
           [ -2.66685280e+00 +2.44249065e-15j,
             -1.77635684e-15 -4.44089210e-16j,
              2.66685280e+00 -3.10862447e-15j,
              3.77202823e+00 -8.88178420e-16j,
              2.66685280e+00 -2.44249065e-15j,
              1.77635684e-15 +4.44089210e-16j,
             -2.66685280e+00 +3.10862447e-15j,
             -3.77202823e+00 +8.88178420e-16j]]), array([ 0.50329212,  2.01316848]))
    """

    # Accept any array-like input. Subtracting the mean already creates a
    # new array, so the caller's data is never modified.
    x = asarray(x)
    xcopy = x - mean(x)

    if extmethod != 'none':
        xcopy = _extend.extend(xcopy, method=extmethod, length=extlength)

    # Frequencies follow the (possibly extended) length, while the set of
    # scales is derived from the original number of samples.
    w = angularfreq(xcopy.shape[0], dt)
    s0 = compute_s0(dt, p, wf)
    s = scales(x.shape[0], dj, dt, s0)

    if wf == "dog":
        wft = waveletb.dogft(s, w, p, dt, norm=True)
    elif wf == "paul":
        wft = waveletb.paulft(s, w, p, dt, norm=True)
    elif wf == "morlet":
        wft = waveletb.morletft(s, w, p, dt, norm=True)
    else:
        raise ValueError("wavelet '%s' is not available" % wf)

    # Multiply the data spectrum by each scale's wavelet spectrum
    # (broadcast over rows) and invert all scales in one call. fft.ifft
    # always returns a complex array, so no coefficient information is
    # lost regardless of the dtype of wft.
    X = fft.ifft(fft.fft(xcopy) * wft, axis=-1)

    # Keep only the columns that correspond to the original samples.
    return X[:, :x.shape[0]], s


def icwt(X, dt, dj, wf="dog", p=2, recf=True):
    r"""Inverse Continuous Wavelet Transform.

    :Parameters:
      X : 2d array_like complex
        transformed data, one row per scale
      dt : float
         time step
      dj : float
         scale resolution (smaller values of dj give finer resolution)
      wf : string ('morlet', 'paul', 'dog')
         wavelet function
      p : float
        wavelet function parameter

          * morlet : 2, 4, 6
          * paul :   2, 4, 6
          * dog :    2, 6, 10

      recf : bool
           use the reconstruction factor (:math:`C_{\delta} \Psi_0(0)`);
           wavelet/parameter pairs not listed above use a factor of 1.0

    :Returns:
      x : 1d ndarray float
        data

    :Raises:
      ValueError : if wf is not available, or if the number of rows of X
        differs from the number of scales implied by X.shape[1], dt and dj

    Example:

    >>> import numpy as np
    >>> import mlpy
    >>> X = np.array([[ -4.66713159e-02 -6.66133815e-16j,
    ...                -3.05311332e-16 +2.77555756e-16j,
    ...                 4.66713159e-02 +1.38777878e-16j,
    ...                 6.94959463e-01 -8.60422844e-16j,
    ...                 4.66713159e-02 +6.66133815e-16j,
    ...                 3.05311332e-16 -2.77555756e-16j,
    ...                -4.66713159e-02 -1.38777878e-16j,
    ...                -6.94959463e-01 +8.60422844e-16j],
    ...              [ -2.66685280e+00 +2.44249065e-15j,
    ...                -1.77635684e-15 -4.44089210e-16j,
    ...                 2.66685280e+00 -3.10862447e-15j,
    ...                 3.77202823e+00 -8.88178420e-16j,
    ...                 2.66685280e+00 -2.44249065e-15j,
    ...                 1.77635684e-15 +4.44089210e-16j,
    ...                -2.66685280e+00 +3.10862447e-15j,
    ...                -3.77202823e+00 +8.88178420e-16j]])
    >>> mlpy.icwt(X=X, dt=1, dj=2, wf='dog', p=2)
    array([ -1.24078928e+00,  -1.07301771e-15,   1.24078928e+00,
             2.32044753e+00,   1.24078928e+00,   1.07301771e-15,
            -1.24078928e+00,  -2.32044753e+00])
    """

    X = asarray(X)

    # compute_s0 raises ValueError for an unknown wavelet, so wf is known
    # to be valid before it is used as a lookup key below.
    s0 = compute_s0(dt, p, wf)
    s = scales(X.shape[1], dj, dt, s0)

    # Every row of X must correspond to exactly one scale; otherwise rows
    # would be silently ignored or missing.
    if X.shape[0] != s.shape[0]:
        raise ValueError("X has %d rows but %d scales are expected"
                         % (X.shape[0], s.shape[0]))

    rf = 1.0
    if recf:
        # Reconstruction factors C_delta * Psi_0(0) per (wavelet, parameter).
        factors = {
            ("dog", 2): 3.13568,
            ("dog", 6): 1.70508,
            ("dog", 10): 1.30445,
            ("paul", 2): 2.08652,
            ("paul", 4): 1.22253,
            ("paul", 6): 0.89730,
            ("morlet", 2): 2.54558,
            ("morlet", 4): 0.92079,
            ("morlet", 6): 0.58470,
        }
        rf = factors.get((wf, p), 1.0)

    # See (11), (13) at page 68: weight each scale's coefficients by
    # 1 / sqrt(s_j), then sum the real parts over scales.
    W = X / sqrt(s)[:, newaxis]
    x = dj * dt ** 0.5 * real(W).sum(axis=0) / rf

    return x

