# -*- coding: utf-8 -*-
import copy
import numpy
import scipy.stats
from scipy.spatial.distance import pdist
from scipy.spatial import cKDTree
from ..general.coord2K import coord2K, coord2Ksplit

from .proba2stat import proba2stat
from .pystks_variable import get_standard_order, get_standard_soft_pdf_type
from .BMEprobaMoments import BMEprobaMoments
from .BMEoptions import BMEoptions
from ..stats.mepdf import maxentpdf_gc, maxentcondpdf_gc
from ..general.valstvgx import valstv2stg,valstg2stv
from ..general.neighbours import neighbours, neighbours_index_kd
from ..mvn.pyAllMoments import pyAllMomentsNG
from ..mvn.qmc import qmc

# Position of each data class inside the "*_all_split" lists used below:
# 'k' -> estimation points, 'h' -> hard data, 's' -> soft data.
KHS_DICT = {'k': 0, 'h': 1, 's': 2}


def _bme_posterior_pdf(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, options=None,
    general_knowledge='gaussian',
    #  specific_knowledge='unknown',  
    pdfk=None, pdfh=None, pdfs=None, hk_k=None, hk_h=None, hk_s=None,
    gui_args=None):
    '''
    Build one posterior pdf callable per estimation point.

    For every row of ck the returned callable evaluates, for an array of
    candidate values xk,

        fG(xk, xh) * fSk(xk) * int fG(xs | xk, xh) fS(xs) dxs
        / fG(xh) / int fG(xs | xh) fS(xs) dxs

    where both integrals are evaluated numerically with qmc.

    ck, ch, cs: coordinates of estimation, hard and soft points
        (2-D arrays; ch may be None).
    zh: hard data values, zs: sequence of soft data records whose first
        item is the pdf type (see get_standard_soft_pdf_type).
    covmodel, covparam: used to build the covariance when covmat is None.
    covmat: optional (nk+nh+ns) by (nk+nh+ns) covariance matrix.
    order: trend order, normalised through get_standard_order.
    options: indexed as options[3,0] (relative error passed to qmc) and
        options['qmc_showinfo'].
    general_knowledge: only 'gaussian' is implemented; any other value
        makes the function return None.
    pdfk, pdfh, pdfs, hk_k, hk_h, hk_s, gui_args: not used here.

    Returns an nk by 1 object array of callables.

    NOTE: the gaussian branch iterates over zs and compares ck against cs,
    so it appears to require soft data -- confirm against callers.
    '''

    def __get_zs_integration_limits(zs):
        '''get soft data integration limit (one (low, high) row per datum)'''
        ranges = []
        for zsi in zs:
            pdftype = get_standard_soft_pdf_type(zsi[0])
            if pdftype in [1, 2]:  # nl, limi, probdens
                nl = zsi[1]
                limi = zsi[2]
                ranges.append((limi[0], limi[nl[0]-1]))
            elif pdftype in [10]:
                # gaussian soft datum: integrate over mean +/- 3 std
                zm = zsi[1]
                zstd = numpy.sqrt(zsi[2])
                ranges.append((zm-3*zstd, zm+3*zstd))
        return numpy.array(ranges)

    order = get_standard_order(order)
    nk = ck.shape[0]
    nh = ch.shape[0] if ch is not None else 0
    ns = cs.shape[0] if cs is not None else 0

    x_all_split = _get_x_all_split(nk, zh, zs)
    Xh = _get_x(x_all_split, 'h')
    mean_all_split = _get_mean_all_split(x_all_split, order)
    # get cov_all_split (3 by 3 nested list of blocks) and the full covmat
    if covmat is None:
        cov_all_split = _get_cov_all_split(ck, ch, cs, covmodel, covparam)
        cov_k_khs = numpy.hstack([i for i in cov_all_split[0] if i is not None])
        cov_h_khs = numpy.hstack([i for i in cov_all_split[1] if i is not None])
        cov_s_khs = numpy.hstack([i for i in cov_all_split[2] if i is not None])
        covmat = numpy.vstack([cov_k_khs, cov_h_khs, cov_s_khs])
    else:
        cov_all_split = numpy.vsplit(
            covmat, [nk, nk+nh, nk+nh+ns]
            )[:-1]  # exclude final empty array
        cov_all_split = \
            [numpy.hsplit(c, [nk, nk+nh, nk+nh+ns][:-1])
             for c in cov_all_split]

    if general_knowledge == 'gaussian':

        def __get_fG_Xkh_each_ck(fG_Xkh_):
            '''wrap the joint (k, h) pdf so it only takes xk values'''
            def fG_Xkh_each_ck(xk):
                xk_origin_shape = xk.shape
                xk = xk.flatten()
                # each row: one candidate xk followed by the fixed hard data
                input_xk = numpy.vstack(
                    (xk, numpy.tile(Xh, xk.size))
                    ).T
                return fG_Xkh_(input_xk).reshape(xk_origin_shape)
            return fG_Xkh_each_ck

        # fG_Xh: const
        fG_Xh = _get_multivariate_normal_pdf(
            x_all_split, mean_all_split, cov_all_split, 'h')(
                _get_x(x_all_split, 'h').T)

        def __get_fG_Xs_gvn_Xh_all_ck():
            fG_Xs_gvn_Xh = _get_multivariate_normal_pdf(
                x_all_split, mean_all_split, cov_all_split, 's_h')
            def fG_Xs_gvn_Xh_all_ck(xs):
                xs_origin_shape = xs.shape
                xs = xs.reshape(-1, xs_origin_shape[-1])
                output = fG_Xs_gvn_Xh(xs)
                return output.reshape(xs_origin_shape[:-1]+(1,))
            return fG_Xs_gvn_Xh_all_ck
        fG_Xs_gvn_Xh = __get_fG_Xs_gvn_Xh_all_ck()

        fS_Xs = _get_fs(zs)
        def __qmc_int_fG_Xs_gvn_Xh__fS_Xs(x_array):
            return fG_Xs_gvn_Xh(x_array) * fS_Xs(x_array)
        zs_limits = __get_zs_integration_limits(zs)
        xmin = zs_limits[:,0].copy()
        xmax = zs_limits[:,1].copy()

        # int_fG_Xs_gvn_Xh__fS_Xs: const (normalisation constant)
        int_fG_Xs_gvn_Xh__fS_Xs, _err, _info = qmc(
            __qmc_int_fG_Xs_gvn_Xh__fS_Xs,
            xmin, xmax, relerr=options[3,0], pow2min=10,
            showinfo=options['qmc_showinfo']
            )
        if int_fG_Xs_gvn_Xh__fS_Xs == 0:  # NC is zero
            print('warning: normolized constant is equal to 0.')

        def __get_fSk_Xk_each_ck(zsk):
            '''soft pdf at an estimation point that coincides with a soft
            datum (zsk), or a constant 1 when there is none (zsk is None)'''
            def fSk_Xk_each_ck(xk):
                if zsk is None:
                    return numpy.ones(xk.shape)
                xk_origin_shape = xk.shape
                xk = xk.flatten()
                pdf_type = get_standard_soft_pdf_type(zsk[0])
                if pdf_type == 2:
                    # piecewise-linear pdf, zero outside the limits
                    nl = zsk[1][0]
                    limi = zsk[2]
                    probdens = zsk[3]
                    y_i = numpy.interp(
                        xk, limi[:nl], probdens[:nl],
                        left=0., right=0.)
                elif pdf_type == 1:
                    # histogram pdf: constant density inside each bin.
                    # NOTE(review): assumes probdens holds one value per
                    # interval (nl - 1 values) -- confirm with the soft
                    # data converter.
                    nl = zsk[1][0]
                    limi = numpy.asarray(zsk[2])[:nl]
                    probdens = numpy.asarray(zsk[3])
                    bins = numpy.searchsorted(limi, xk, side='right') - 1
                    inside = (bins >= 0) & (bins < nl - 1)
                    y_i = numpy.where(
                        inside,
                        probdens[numpy.clip(bins, 0, nl - 2)],
                        0.)
                elif pdf_type == 10:
                    zm = zsk[1]
                    zstd = numpy.sqrt(zsk[2])
                    y_i = scipy.stats.norm.pdf(xk, loc=zm, scale=zstd)
                else:
                    raise ValueError(
                        'unsupported soft pdf type: %r' % (pdf_type,))
                return y_i.reshape(xk_origin_shape)
            return fSk_Xk_each_ck

        def __get_fG_Xs_gvn_Xkh_each_ck(ck_i, cs, zs,
            x_all_split_each_ck, mean_all_split_each_ck, cov_all_split_each_ck):
            '''factory of fG(xs | xk, xh) for one estimation point; a soft
            datum located exactly at ck_i is removed from the soft set'''
            idx_result = numpy.where(numpy.all(ck_i == cs, axis=1))[0]
            if idx_result.size > 1:  # strange
                raise ValueError('ck match cs twice. (strange)')
            if idx_result.size == 0:
                x_all_split_each_ck_dup = x_all_split_each_ck
                mean_all_split_each_ck_dup = mean_all_split_each_ck
                cov_all_split_each_ck_dup = cov_all_split_each_ck
            else:
                x_all_split_each_ck_dup =\
                    x_all_split_each_ck[:2] +\
                    [numpy.delete(x_all_split_each_ck[2], idx_result, axis=0)]
                mean_all_split_each_ck_dup =\
                    mean_all_split_each_ck[:2] +\
                    [numpy.delete(mean_all_split_each_ck[2], idx_result, axis=0)]
                cov_all_split_each_ck_dup = copy.deepcopy(cov_all_split_each_ck)
                # drop the duplicated soft column from the k-s and h-s blocks
                cov_all_split_each_ck_dup[0][2] =\
                    numpy.delete(cov_all_split_each_ck[0][2], idx_result, axis=1)
                cov_all_split_each_ck_dup[1][2] =\
                    numpy.delete(cov_all_split_each_ck[1][2], idx_result, axis=1)
                # drop the duplicated soft row from the s-k and s-h blocks
                cov_all_split_each_ck_dup[2][0] =\
                    numpy.delete(cov_all_split_each_ck[2][0], idx_result, axis=0)
                cov_all_split_each_ck_dup[2][1] =\
                    numpy.delete(cov_all_split_each_ck[2][1], idx_result, axis=0)
                # the s-s block loses both the row and the column
                cov_all_split_each_ck_dup[2][2] =\
                    numpy.delete(cov_all_split_each_ck[2][2], idx_result, axis=0)
                cov_all_split_each_ck_dup[2][2] =\
                    numpy.delete(cov_all_split_each_ck_dup[2][2], idx_result, axis=1)

            def __get_fG_Xs_gvn_Xkh_each_ck_eack_xk(xk):
                fG_Xs_gvn_Xkh_container = []
                xk_origin_shape = xk.shape
                xk = xk.flatten()
                for xk_i in xk:
                    # condition on this particular candidate value of xk
                    x_all_split_each_ck_dup[0] = numpy.array([[xk_i]])
                    fG_Xs_gvn_Xkh_ = _get_multivariate_normal_pdf(
                        x_all_split_each_ck_dup,
                        mean_all_split_each_ck_dup,
                        cov_all_split_each_ck_dup, 's_kh')
                    def fG_Xs_gvn_Xkh_each_ck_eack_xk(xs, ff):
                        xs_origin_shape = xs.shape
                        xs = xs.reshape(-1, xs_origin_shape[-1])
                        output = ff(xs)
                        return output.reshape(xs_origin_shape[:-1]+(1,))
                    # ff is bound as a default to avoid late binding
                    fG_Xs_gvn_Xkh_container.append(
                        lambda xs, ff=fG_Xs_gvn_Xkh_: fG_Xs_gvn_Xkh_each_ck_eack_xk(xs, ff))
                return numpy.array(
                    fG_Xs_gvn_Xkh_container).reshape(xk_origin_shape)
            return __get_fG_Xs_gvn_Xkh_each_ck_eack_xk

        def __get_fS_Xs_dup(ck_i, cs, zs):
            '''product of soft pdfs without the datum located at ck_i'''
            idx_result = numpy.where(numpy.all(ck_i == cs, axis=1))[0]
            if idx_result.size > 1:  # strange
                raise ValueError('ck match cs twice. (strange)')
            if idx_result.size == 0:
                return fS_Xs
            dup = idx_result[0]
            return _get_fs(
                [zs_i for ii, zs_i in enumerate(zs) if ii != dup])

        def __get_pdf_each_ck(i):
            '''posterior pdf callable for the i-th estimation point'''
            ck_i = ck[i]
            idx_result = numpy.where(numpy.all(ck_i == cs, axis=1))[0]
            if idx_result.size > 1:  # strange
                raise ValueError('ck match cs twice. (strange)')
            if idx_result.size == 0:
                zs_dup = zs
            else:
                dup = idx_result[0]
                zs_dup = [zs_i for ii, zs_i in enumerate(zs) if ii != dup]
            def _fK_Xk(xk):
                xk_origin_shape = xk.shape
                xk = xk.flatten()
                def __qmc_int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup(x_array):
                    G_Xs_gvn_Xkh_dup_i = numpy.hstack([fi(x_array) for fi in fG_Xs_gvn_Xkh_dup[i](xk)])
                    fS_Xs_dup_i = fS_Xs_dup[i](x_array)
                    return G_Xs_gvn_Xkh_dup_i * fS_Xs_dup_i
                zs_limits = __get_zs_integration_limits(zs_dup)
                xmin = zs_limits[:,0].copy()
                xmax = zs_limits[:,1].copy()

                int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup, _err, _info = qmc(
                    __qmc_int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup,
                    xmin, xmax, relerr=options[3,0],
                    showinfo=options['qmc_showinfo']
                    )
                int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup =\
                    int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup.reshape(xk_origin_shape)
                xk = xk.reshape(xk_origin_shape)

                return (fG_Xkh[i](xk) * fSk_Xk[i](xk)
                    * int_fG_Xs_gvn_Xkh_dup__fS_Xs_dup
                    / fG_Xh / int_fG_Xs_gvn_Xh__fS_Xs)
            return _fK_Xk

        # rows/columns of covmat that belong to hard and soft data
        cov_hs_range = list(range(nk, nk+nh+ns))

        fG_Xkh = []  # a list contains each ck's fG_Xkh
        fSk_Xk = []  # a list contains each ck's fSk_Xk
        fS_Xs_dup = []
        fG_Xs_gvn_Xkh_dup = []  # a list contains each ck's fG_Xs_gvn_Xkh_dup
        fK_Xk = []
        for i in range(nk):
            # get x/mean/cov_all_split at each ck point
            x_all_split_each_ck =\
                [x_all_split[0][i:i+1,:]] + x_all_split[1:]
            mean_all_split_each_ck =\
                [mean_all_split[0][i:i+1,:]] + mean_all_split[1:]
            covmat_each_ck =\
                covmat[numpy.ix_(
                    [i]+cov_hs_range,
                    [i]+cov_hs_range
                    )]
            cov_all_split_each_ck = numpy.vsplit(
                covmat_each_ck, [1, 1+nh, 1+nh+ns]
                )[:-1]  # exclude final empty array
            cov_all_split_each_ck = \
                [numpy.hsplit(c, [1, 1+nh, 1+nh+ns][:-1])
                 for c in cov_all_split_each_ck]

            # get fG_Xkh each ck part
            fG_Xkh_ = _get_multivariate_normal_pdf(
                x_all_split_each_ck,
                mean_all_split_each_ck,
                cov_all_split_each_ck, 'kh')
            fG_Xkh.append(
                __get_fG_Xkh_each_ck(fG_Xkh_))

            # get fSk_Xk
            idx_result = numpy.where(numpy.all(ck[i] == cs, axis=1))[0]
            if idx_result.size > 1:  # strange
                raise ValueError('ck match cs twice. (strange)')
            if idx_result.size == 0:
                fSk_Xk.append(__get_fSk_Xk_each_ck(None))
            else:
                fSk_Xk.append(__get_fSk_Xk_each_ck(zs[idx_result[0]]))

            # get fS_Xs_dup
            fS_Xs_dup.append(__get_fS_Xs_dup(ck[i], cs, zs))

            # get fG_Xs_gvn_Xkh
            fG_Xs_gvn_Xkh_dup.append(
                __get_fG_Xs_gvn_Xkh_each_ck(ck[i], cs, zs,
                    x_all_split_each_ck,
                    mean_all_split_each_ck,
                    cov_all_split_each_ck))

            fK_Xk.append(__get_pdf_each_ck(i))

        return numpy.array([fK_Xk]).reshape((-1, 1))
    else:  # general knowledge is not gaussian: not implemented
        return None

def _bme_posterior_moments(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, options=None,
    general_knowledge='gaussian',
    #  specific_knowledge='unknown',  
    pdfk=None, pdfh=None, pdfs=None, hk_k=None, hk_h=None, hk_s=None,
    gui_args=None):

    '''
        Posterior moments at the estimation points.

        no neighbour considered, so spatial-temporal range
        should be transform first (no dmax support).

        covmat:
            covariance matrix, a numpy 2d array
            with shape (nk+nh+ns) by (nk+nh+ns)
        if covmat provided, covmodel and covparam are simply skipped.

        options:
            indexed as options[3,0] (relative error passed to qmc) and
            options['qmc_showinfo']; only read when some soft datum is
            not gaussian.

        Without soft data, or when every soft datum is gaussian (type 10),
        the work is delegated to _bme_proba_gaussian. Otherwise the
        moments are integrated numerically with qmc.

        Returns an nk by 3 array (mean, variance, third column).
        Raises ValueError when general_knowledge is not 'gaussian'.
    '''

    if general_knowledge != 'gaussian':
        raise ValueError("Now we can not consider non-gaussian GK." )

    if not zs:
        return _bme_proba_gaussian(
            ck, ch, cs, zh, zs, covmodel, covparam, covmat, order)

    # a list comprehension (rather than map) keeps this an integer array
    # on both Python 2 and Python 3
    all_zs_type = numpy.array(
        [get_standard_soft_pdf_type(zsi[0]) for zsi in zs])
    if (all_zs_type == 10).all():  # all soft type are gaussian
        return _bme_proba_gaussian(
            ck, ch, cs, zh, zs,
            covmodel, covparam, covmat, order)

    # at least one non-gaussian soft datum: numerical integration
    nk = ck.shape[0]
    nh = ch.shape[0] if ch is not None else 0
    ns = cs.shape[0]
    x_all_split = _get_x_all_split(nk, zh, zs)
    mean_all_split = _get_mean_all_split(x_all_split, order)
    if covmat is not None:
        cov_all_split = numpy.vsplit(
            covmat, [nk, nk+nh, nk+nh+ns]
            )[:-1]  # exclude final empty array
        cov_all_split = \
            [numpy.hsplit(c, [nk, nk+nh, nk+nh+ns][:-1])
             for c in cov_all_split]
    else:
        cov_all_split = _get_cov_all_split(
            ck, ch, cs, covmodel, covparam)

    fg_s_given_h = _get_multivariate_normal_pdf(
        x_all_split, mean_all_split,
        cov_all_split, 's_h')
    fs = _get_fs(zs)

    # split hard and soft of m_k_gvn_hs
    m_k = _get_mean(mean_all_split, 'k')
    Bm_sigma_inv_multi = (
        _get_sigma(cov_all_split, 'k', 'hs').dot(
            _get_sigma(cov_all_split, 'hs', 'hs', inv=True))
        )
    m_hs = _get_mean(mean_all_split, 'hs')
    m_k_gvn_hs_part = m_k - Bm_sigma_inv_multi.dot(m_hs)
    # stacked hard and soft values; the soft rows are placeholders that
    # are overwritten with the integration nodes. Missing (None) parts are
    # skipped so that the row layout matches nh.
    x_hs = numpy.vstack([x for x in x_all_split[1:] if x is not None])

    sigma_k_given_hs = _get_sigma_a_given_b(
        cov_all_split, 'k', 'hs')
    diag_sigma_k_given_hs =\
        numpy.diag(sigma_k_given_hs)

    def func_moments(x_array):
        '''integrand: nMon blocks of nk columns plus the normalisation
        column, one row per integration node'''
        nMon = 3
        npts = x_array.shape[0]
        res = numpy.empty(
            (npts, nMon*nk + 1)
            )
        x_hs_npts = numpy.tile(x_hs, (1, npts))
        x_hs_npts[nh:,:] = x_array.T  # put true Xs part

        m_k_gvn_hs_npts = (
            m_k_gvn_hs_part + Bm_sigma_inv_multi.dot(x_hs_npts)
            ).T

        fg_fs = (
            fg_s_given_h(x_array).reshape((npts,1)) * fs(x_array)
            )
        res[:, 0*nk:1*nk] = m_k_gvn_hs_npts * fg_fs
        res[:, 1*nk:2*nk] = (res[:,:nk] * m_k_gvn_hs_npts)
        # NOTE(review): the raw third moment of N(m, s) is m**3 + 3*s*m,
        # whereas this integrand uses 3*s*m - 2*m**3 -- confirm against
        # the reference derivation before relying on the third column.
        res[:, 2*nk:3*nk] = (
            3 * diag_sigma_k_given_hs * res[:,:nk]
            - 2 * res[:,nk:2*nk] * m_k_gvn_hs_npts
            )
        res[:,-1:] = fg_fs
        return res

    # integration limits, one (low, high) pair per soft datum
    ranges = []
    for zsi in zs:
        pdftype = get_standard_soft_pdf_type(zsi[0])
        if pdftype in [1, 2]:  # nl, limi, probdens
            nl = zsi[1]
            limi = zsi[2]
            ranges.append((limi[0], limi[nl[0]-1]))
        elif pdftype in [10]:
            # mean and variance of THIS soft datum (zsi, not zs)
            zm = zsi[1]
            zstd = numpy.sqrt(zsi[2])
            ranges.append((zm-5*zstd, zm+5*zstd))
    ranges = numpy.array(ranges)
    xmin = ranges[:,0].copy()
    xmax = ranges[:,1].copy()

    Mon, _err, _info = qmc(
        func_moments, xmin, xmax, relerr=options[3,0],
        showinfo=options['qmc_showinfo'])

    Mon_NC = Mon[-1]
    Mon = Mon[:-1].reshape((-1, nk)).T  # k by nmon
    Mon /= Mon_NC  # for moments 1,2,3
    Mon[:,1:2] += (
        diag_sigma_k_given_hs.reshape((-1, 1))
        )  # for moments 2

    M1 = Mon[:, 0:1]
    M2 = Mon[:, 1:2]
    M3 = Mon[:, 2:3]

    mvs = numpy.empty(Mon.shape)
    mvs[:,0:1] = M1
    mvs[:,1:2] = M2 - M1**2
    mvs[:,2:3] = M3 - 3*M2*M1 + 2*M1**3

    return mvs

def _bme_proba_gaussian(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, ck_cov_output=False):
    '''
    Closed-form posterior moments when every soft datum is gaussian.

    no neighbour consider, no data format transform.

    ch, cs, zh: np.2darray or None
    zs: new zs data or None, see softconverter.py for detail.
    covmat: optional (nk+nh+ns) by (nk+nh+ns) covariance matrix; when
        given, covmodel and covparam are not used.
    order: trend order, passed as-is to _get_mean_all_split.
    ck_cov_output: if True, result will additionally return
        covariance between ck
    NOTE zs there should gaussian type e.g.
        zs = ((10, mean1, var1),...,(10, mean2, var2))

    Returns mvs, an nk by 3 array (mean, variance, zero skewness), or the
    tuple (mvs, covariance between ck) when ck_cov_output is True.
    NOTE: when there is neither hard nor soft data a NaN-filled mvs is
    returned on its own, even if ck_cov_output is True.
    '''
    nk = ck.shape[0]
    nh = ch.shape[0] if ch is not None else 0
    ns = cs.shape[0] if cs is not None else 0

    x_all_split = _get_x_all_split(nk, zh, zs)
    mean_all_split = _get_mean_all_split(x_all_split, order)
    if covmat is not None:
        cov_all_split = numpy.vsplit(
            covmat, [nk, nk+nh, nk+nh+ns]
            )[:-1]  # exclude final empty array
        cov_all_split = \
            [numpy.hsplit(c, [nk, nk+nh, nk+nh+ns][:-1])
             for c in cov_all_split]
    else:
        cov_all_split =\
            _get_cov_all_split(ck, ch, cs, covmodel, covparam)

    if ns == 0 and nh == 0:
        # no data at all: nothing can be estimated
        mvs = numpy.empty((ck.shape[0], 3))
        mvs[:] = numpy.nan
        return mvs

    if ns == 0:  # only hard data: plain conditional gaussian
        mean_k_given_h = _get_mean_a_given_b(
            x_all_split, mean_all_split,
            cov_all_split, sub_a='k', sub_b='h')
        sigma_k_given_h = _get_sigma_a_given_b(
            cov_all_split, sub_a='k', sub_b='h')

        skewness = numpy.zeros(mean_k_given_h.shape)
        mvs = numpy.hstack(
            (mean_k_given_h, sigma_k_given_h.diagonal().reshape((-1,1)),
            skewness)
            )
        if ck_cov_output:
            return mvs, sigma_k_given_h
        return mvs

    # both hard and soft data (hard data can be empty)
    mean_k = _get_mean(mean_all_split, 'k')
    mean_hs = _get_mean(mean_all_split, 'hs')
    mean_s_given_h = _get_mean_a_given_b(
        x_all_split, mean_all_split,
        cov_all_split, sub_a='s', sub_b='h')
    NC, useful_args = _get_int_fg_a_given_b_fs_s(
        x_all_split, mean_all_split,
        cov_all_split, zs,
        sub_multi='s_h', sub_s='s'
        )
    (sigma_t_prime, inv_sigma_s_given_h,
     inv_sigma_tilde_s, mean_tilde_s) = useful_args
    # hat_x_* are first moments of (hard, soft) values scaled by NC
    hat_x_s = NC * sigma_t_prime.dot(
        inv_sigma_s_given_h.dot(mean_s_given_h)
        + inv_sigma_tilde_s.dot(mean_tilde_s)
        )
    if not nh:
        hat_x_hs = hat_x_s
    else:
        hat_x_h = _get_x(x_all_split, 'h') * NC
        hat_x_hs = numpy.vstack((hat_x_h, hat_x_s))
    sigma_k_hs = _get_sigma(
        cov_all_split, sub_a='k', sub_b='hs')
    inv_sigma_hs_hs = _get_sigma(
        cov_all_split, sub_a='hs', sub_b='hs', inv=True)
    cond_k_hs = sigma_k_hs.dot(inv_sigma_hs_hs)
    BME_mean_k_given_hs_a = cond_k_hs.dot(hat_x_hs)
    BME_mean_k_given_hs_b = mean_k - cond_k_hs.dot(mean_hs)
    BME_mean_k_given_hs =\
        BME_mean_k_given_hs_a/NC + BME_mean_k_given_hs_b

    sigma_k_given_hs = _get_sigma_a_given_b(
        cov_all_split, sub_a='k', sub_b='hs')
    mean_t = hat_x_hs/NC
    aa = numpy.zeros(inv_sigma_hs_hs.shape)
    aa[nh:nh+ns, nh:nh+ns] = sigma_t_prime*NC
    bb = (mean_t - mean_hs).dot((mean_t - mean_hs).T) * NC
    tt = cond_k_hs.dot(aa + bb).dot(cond_k_hs.T)

    if not ck_cov_output:
        sigma_k_given_hs_diag =\
            sigma_k_given_hs.diagonal().reshape((-1,1))
        tt_diag = tt.diagonal().reshape((-1,1))
        BME_var_k_given_hs = (
            sigma_k_given_hs_diag - BME_mean_k_given_hs**2 + mean_k**2
            - 2*mean_k * cond_k_hs.dot(mean_hs)
            + 2*mean_k * cond_k_hs.dot(hat_x_hs) / NC
            + tt_diag / NC
            )
    else:
        # Cross term between the trend at ck and the conditional shift.
        # It uses mean_t (= hat_x_hs / NC) so that the diagonal equals the
        # variance computed in the branch above, and it is added together
        # with its transpose so that the covariance stays symmetric.
        cross = cond_k_hs.dot(mean_t - mean_hs).dot(mean_k.T)
        BME_var_k_given_hs_cov = (
            sigma_k_given_hs
            - BME_mean_k_given_hs.dot(BME_mean_k_given_hs.T)
            + mean_k.dot(mean_k.T)
            + cross + cross.T
            + tt / NC
            )
        BME_var_k_given_hs =\
            BME_var_k_given_hs_cov.diagonal().reshape((-1,1))

    skewness = numpy.zeros(BME_mean_k_given_hs.shape)
    mvs = numpy.hstack(
        (BME_mean_k_given_hs, BME_var_k_given_hs, skewness)
        )
    if ck_cov_output:
        return mvs, BME_var_k_given_hs_cov
    return mvs

def _get_x_all_split(nk, zh, zs):
    '''
    Assemble the [k, h, s] list of "observed" value columns.

    k: an uninitialised (nk, 1) placeholder, filled in by callers later.
    h: zh itself, passed through untouched.
    s: one representative value per soft datum -- the stored mean for
       gaussian (type 10) data, otherwise the first value returned by
       proba2stat -- or None when zs is empty/None.
    '''
    placeholder_k = numpy.empty((nk, 1))  # will be replaced later
    if not zs:
        return [placeholder_k, zh, None]

    soft_values = numpy.empty((len(zs), 1))
    for row, soft in enumerate(zs):
        is_gaussian = get_standard_soft_pdf_type(soft[0]) == 10
        if is_gaussian:
            soft_values[row] = soft[1]  # z_mean
        else:
            soft_values[row], _unused_var = proba2stat(
                soft[0],
                numpy.array([soft[1]]),
                numpy.array([soft[2]]),
                numpy.array([soft[3]])
                )
    return [placeholder_k, zh, soft_values]

def _get_mean_all_split(x_all_split, order):
    '''
    Obtain the trend estimations at the estimation and data locations based
    upon the specified trend order.

    x_all_split: [k, h, s] list of value arrays (entries may be None).
    order: numpy.nan for a zero mean, 0 for a constant mean equal to the
        average of all hard and soft values (estimation points excluded).

    Returns a list shaped like x_all_split holding the mean arrays, with
    None kept wherever the input entry is None.
    Raises ValueError for any other order (only NaN and 0 are implemented).
    '''
    if numpy.isnan(order):  # zero mean
        trend_value = 0.
    elif order == 0:  # constant mean, exclude zk, average h and s
        observed = [x for x in x_all_split[1:] if x is not None]
        trend_value = numpy.vstack(observed).mean()
    else:
        raise ValueError(
            'unsupported trend order: %r (only NaN and 0 are handled)'
            % (order,))
    return [
        None if x is None else numpy.ones(x.shape) * trend_value
        for x in x_all_split
        ]

def _get_cov_all_split(ck, ch, cs, covmodel, covparam):
    '''
    Covariance between the (k, h, s) coordinate groups, in split form.

    Calls coord2Ksplit with the same three coordinate groups on both
    sides and returns the first item of its result.
    '''
    coord_groups = (ck, ch, cs)
    split_result = coord2Ksplit(coord_groups, coord_groups,
                                covmodel, covparam)
    return split_result[0]

def _get_x(x_all_split, sub):
    '''
    Stack the value arrays of the classes named in sub.

    sub is a string made of 'k', 'h' and 's' (estimation, hard and soft
    data); the selected entries of x_all_split are stacked vertically in
    the order the letters appear.
    '''
    positions = [KHS_DICT[letter] for letter in sub]
    selected = []
    for pos in positions:
        selected.append(x_all_split[pos])
    return numpy.vstack(selected)

def _get_mean(mean_all_split, sub):
    '''
    Stack the mean arrays of the classes named in sub.

    sub is a string made of 'k', 'h' and 's' (estimation, hard and soft
    data). Entries that are None are skipped; None is returned when
    nothing is left to stack.
    '''
    positions = [KHS_DICT[letter] for letter in sub]
    available = []
    for pos in positions:
        part = mean_all_split[pos]
        if part is not None:
            available.append(part)
    if not available:
        return None
    return numpy.vstack(available)

def _get_sigma(cov_all_split, sub_a, sub_b, inv=False):
    '''
    Assemble the cross-covariance between the classes in sub_a (rows) and
    sub_b (columns); both are strings made of 'k', 'h' and 's'.

    Blocks that are None are left out of each row. When more than one row
    class is requested, rows without any block are dropped before
    stacking. With inv=True the inverse of the assembled matrix is
    returned instead.
    '''
    row_positions = [KHS_DICT[letter] for letter in sub_a]
    col_positions = [KHS_DICT[letter] for letter in sub_b]

    row_blocks = []
    for r in row_positions:
        present = [cov_all_split[r][c] for c in col_positions
                   if cov_all_split[r][c] is not None]
        row_blocks.append(numpy.hstack(present) if present else None)

    if len(row_blocks) > 1:
        row_blocks = [blk for blk in row_blocks if blk is not None]
    assembled = numpy.vstack(row_blocks)

    if inv:
        return numpy.linalg.inv(assembled)
    return assembled

def _get_mean_a_given_b(x_all_split, mean_all_split,
    cov_all_split, sub_a, sub_b):
    '''
    Obtain the conditional mean of a given b by using the conditional
    Gaussian formula

        mean_a + sigma_a_b . inv(sigma_b_b) . (x_b - mean_b)

    sub_a, sub_b: strings made of 'k', 'h' and 's'.

    When sub_b contains 'k' and x_all_split[0] holds more than one value,
    every value becomes its own column of x_b: the k values form the first
    row and the remaining conditioning values are repeated in each column.

    Returns mean_a unchanged when there is no mean for sub_b (i.e. the
    conditioning data do not exist).
    '''
    nlim = numpy.asarray(x_all_split[0]).size if 'k' in sub_b else 0
    if nlim > 1:
        # compare by value: identity tests against a literal are unreliable
        idx = [KHS_DICT[i] for i in sub_b if i != 'k']
        xhs = numpy.vstack([x_all_split[i] for i in idx])
        smtx = numpy.ones((1, nlim))
        x_b = numpy.vstack(
            [x_all_split[0].reshape((1, nlim)), xhs.dot(smtx)])
    else:
        x_b = _get_x(x_all_split, sub_b)

    mean_a = _get_mean(mean_all_split, sub_a)
    mean_b = _get_mean(mean_all_split, sub_b)
    if mean_b is None:
        # consider the case that data in sub_b does not exist
        return mean_a
    sigma_a_b = _get_sigma(cov_all_split, sub_a, sub_b)
    inv_sigma_b_b = _get_sigma(cov_all_split, sub_b, sub_b, inv=True)
    return mean_a + sigma_a_b.dot(inv_sigma_b_b).dot(x_b - mean_b)

def _get_sigma_a_given_b(cov_all_split, sub_a, sub_b):
    '''
    Conditional covariance of a given b from the conditional Gaussian
    formula  sigma_a_a - sigma_a_b . inv(sigma_b_b) . sigma_b_a.

    When the a-b cross-covariance is empty, sigma_a_a is returned as is.
    '''
    prior_cov = _get_sigma(cov_all_split, sub_a, sub_a)
    cross_ab = _get_sigma(cov_all_split, sub_a, sub_b)
    if cross_ab.size <= 0:
        return prior_cov
    precision_b = _get_sigma(cov_all_split, sub_b, sub_b, inv=True)
    cross_ba = _get_sigma(cov_all_split, sub_b, sub_a)
    reduction = cross_ab.dot(precision_b).dot(cross_ba)
    return prior_cov - reduction

def _get_multivariate_normal_pdf(x_all_split, mean_all_split,
    cov_all_split, sub_multi):
    '''
    Return the pdf method of a (conditional) multivariate Gaussian.

    sub_multi   string   letters h, s and k stand for hard, soft and
                         estimation locations; "a_b" means a given b
                         (e.g. k_h), a string without "_" means the
                         plain joint distribution of those classes.
    '''
    if "_" not in sub_multi:  # single_sub
        mean_vec = _get_mean(mean_all_split, sub_multi)
        cov = _get_sigma(cov_all_split, sub_multi, sub_multi)
    else:  # "given" type
        target, given = sub_multi.split('_')
        mean_vec = _get_mean_a_given_b(x_all_split, mean_all_split,
                                       cov_all_split, target, given)
        cov = _get_sigma_a_given_b(cov_all_split, target, given)
    frozen = scipy.stats.multivariate_normal(mean_vec.T[0], cov)
    return frozen.pdf


def _get_fs(zs):
    '''
    the product of fs distributions

    zs: sequence of soft data records; zsi[0] is the pdf type, followed
        by (nl, limi, probdens) for types 1 and 2 or (mean, variance) for
        type 10.

    Returns fs(x), where
    x   ndim(e.g. npts) by ns
            x is a 2-D np array with the dimension of
            ndim(number of samples at each integral)
            by ns (the number of integrals, i.e., number
            of soft data)
    fs(x) is an npts by 1 array holding the product of the soft pdfs,
    column k of x being evaluated with the k-th soft datum. It raises
    ValueError for a soft pdf type that is not handled here.
    '''
    def fs(x):
        res = numpy.ones((x.shape[0], 1))
        for idx_k, zsi in enumerate(zs):
            x_k = x[:, idx_k:idx_k+1]
            pdf_type = get_standard_soft_pdf_type(zsi[0])
            if pdf_type == 2:
                # piecewise-linear pdf, zero outside the limits
                nl = zsi[1][0]
                limi = zsi[2]
                probdens = zsi[3]
                y_i = numpy.interp(
                    x_k, limi[:nl], probdens[:nl],
                    left=0., right=0.)
            elif pdf_type == 1:
                # histogram pdf: constant density inside each bin.
                # NOTE(review): assumes probdens holds one value per
                # interval (nl - 1 values) -- confirm with the soft data
                # converter.
                nl = zsi[1][0]
                limi = numpy.asarray(zsi[2])[:nl]
                probdens = numpy.asarray(zsi[3])
                bins = numpy.searchsorted(limi, x_k, side='right') - 1
                inside = (bins >= 0) & (bins < nl - 1)
                y_i = numpy.where(
                    inside,
                    probdens[numpy.clip(bins, 0, nl - 2)],
                    0.)
            elif pdf_type == 10:
                zm = zsi[1]
                zstd = numpy.sqrt(zsi[2])
                y_i = scipy.stats.norm.pdf(x_k, loc=zm, scale=zstd)
            else:
                raise ValueError(
                    'unsupported soft pdf type: %r' % (pdf_type,))

            # one factor that is zero everywhere zeroes the whole product
            if not (y_i.all() or y_i.any()):
                return numpy.zeros((x.shape[0], 1))
            res *= y_i
        return res
    return fs

def _get_int_fg_a_given_b_fs_s(x_all_split, mean_all_split,
                               cov_all_split, zs, sub_multi, sub_s='s'):
    '''
    The upper right part and lower right part of the last row of formula (1)
    The evaluation is based upon Eqns. (8) or (9) in the cases of s_h and s_kh
    respectively

    x_all_split, mean_all_split, cov_all_split
                split values / means / covariances, forwarded to
                _get_mean_a_given_b and _get_sigma_a_given_b
    zs          sequence of soft data; zsi[1] is used as the soft mean
                and zsi[2] as the soft variance (diagonal of the soft
                covariance matrix)
    sub_multi   string 'a_b', e.g. 's_kh' or 's_h'
    sub_s       not used by this function

    return (value, (sigma_t, inv_sigma_a_given_b,
                    inv_sigma_tilde_s, mean_tilde_s))
        value is a 1-D array

    raise numpy.linalg.LinAlgError when one of the matrices is singular
    '''
    # fg='s_kh', fs='s'
    sub_a, sub_b = sub_multi.split('_')
    sigma_a_given_b = _get_sigma_a_given_b(cov_all_split, sub_a, sub_b)
    inv_sigma_a_given_b = numpy.linalg.inv(sigma_a_given_b)

    # soft-data mean (column vector) and diagonal covariance matrix
    mean_tilde_s = numpy.array([[zsi[1]] for zsi in zs])
    sigma_tilde_s = numpy.diag([zsi[2] for zsi in zs])
    inv_sigma_tilde_s = numpy.linalg.inv(sigma_tilde_s)

    sigma_t = numpy.linalg.inv(inv_sigma_a_given_b + inv_sigma_tilde_s)
    ns = mean_tilde_s.shape[0]

    # normalising constant in front of the exponential
    det_product = numpy.linalg.det(sigma_a_given_b) *\
        numpy.linalg.det(sigma_tilde_s)
    fgfs_front = numpy.sqrt(numpy.linalg.det(sigma_t)) /\
        ((2*numpy.pi)**(ns/2.) * numpy.sqrt(det_product))

    mean_a_given_b = _get_mean_a_given_b(
        x_all_split, mean_all_split, cov_all_split, sub_a, sub_b)

    # quadratic forms of the exponent; numpy.diag keeps one value per
    # column of mean_a_given_b
    quad_g = numpy.diag(
        (mean_a_given_b.T).dot(inv_sigma_a_given_b).dot(mean_a_given_b))
    quad_s = (mean_tilde_s.T).dot(inv_sigma_tilde_s).dot(mean_tilde_s)
    cross_left = (mean_a_given_b.T).dot(inv_sigma_a_given_b) +\
        (mean_tilde_s.T).dot(inv_sigma_tilde_s)
    cross_right = inv_sigma_tilde_s.dot(mean_tilde_s) +\
        inv_sigma_a_given_b.dot(mean_a_given_b)
    quad_cross = numpy.diag(cross_left.dot(sigma_t).dot(cross_right))

    fgfs_end = numpy.exp((-1/2.) * (quad_g + quad_s - quad_cross))
    fgfs_end = fgfs_end.reshape(fgfs_end.size)

    return fgfs_front * fgfs_end, \
        (sigma_t, inv_sigma_a_given_b, inv_sigma_tilde_s, mean_tilde_s)

def _changetimeform(ck,ch=None,cs=None):
  '''
  Change the time format into float while it is in datetime format
  '''  
  
  if type(ck[0,-1])==numpy.datetime64:
    origin=ck[0,-1]
    ck[:,-1]=numpy.double(numpy.asarray(ck[:,-1],dtype='datetime64')-origin)
    ck=ck.astype(numpy.double)
    if ch is not None and ch.size>0:
      if (not type(ch[0,-1]==numpy.datetime64)):
        print 'Time format of ch is not consistent with ck (np.datetime64)'
        raise
      ch[:,-1]=numpy.double(numpy.asarray(ch[:,-1],dtype='datetime64')-origin)
      ch=ch.astype(numpy.double)
    if cs is not None and cs.size>0:
      if (not type(cs[0,-1]==numpy.datetime64)):
        print 'Time format of cs is not consistent with ck (np.datetime64)'
        raise
      cs[:,-1]=numpy.double(numpy.asarray(cs[:,-1],dtype='datetime64')-origin)
      cs=cs.astype(numpy.double)
      
  return ck,ch,cs

def _set_nh_ns(ck,ch,cs,nhmax,nsmax,dmax):
  '''
  Set the size of nhmax and nsmax that limits the size of matrix to be allocated
  it can be important for an efficient S/T estimation
  '''
  if dmax is not None and numpy.all(dmax):
    return nhmax,nsmax,dmax
  
  if ck[0,:].size<3:
    if dmax is None:
      dmax_=0
      if ch is not None:
        ch=numpy.array(ch,ndmin=2)
        maxd_h=pdist(ch).max()
        dmax_=numpy.max([dmax_,maxd_h])
      if cs is not None:
        cs=numpy.array(cs,ndmin=2)
        maxd_s=pdist(cs).max()
        dmax_=numpy.max([dmax_,maxd_s])
      dmax=numpy.array(dmax_).reshape(1,1)

    if nhmax is None:
      if ch is not None:
        nhmax=ch.shape[0]
      else:
        nhmax=0

    if nsmax is None:
      if cs is not None:
        nsmax=cs.shape[0]
      else:
        nsmax=0

    
  else:
    maxd=0
    maxt=0
    if dmax is None:
      if ch is not None:
        dummy=numpy.random.rand(ch.shape[0],1)
        _,cMS_h,tME_h,_=valstv2stg(ch,dummy)
        if nhmax is None:
          nhmax=cMS_h.shape[0]*3
        maxd_h=pdist(cMS_h).max()
        maxt_h=pdist(tME_h.reshape((tME_h.size,1))).max()
      else:
        maxd_h=0
        maxt_h=0
        nhmax=0
      maxd=numpy.max([maxd_h,maxd]) 
      maxt=numpy.max([maxt_h,maxt])
      if cs is not None: 
        dummy=numpy.random.rand(cs.shape[0],1)
        _,cMS_s,tME_s,_=valstv2stg(cs,dummy)
        if nsmax is None:
          if zs[0]==10 or zs[0] is 'gaussian':
            nsmax=cMS_s.shape[0]*3
          else:
            nsmax=3
        maxd_s=pdist(cMS_s).max()
        maxt_s=pdist(tME_s.reshape((tME_s.size,1))).max()
        maxd=numpy.max([maxd_s,maxd])
        maxt=numpy.max([maxt_s,maxt])
      else:
        nsmax=0
        maxd_s=0
        maxt_s=0
      maxd=numpy.max([maxd_s,maxd])
      maxt=numpy.max([maxt_s,maxt])
    dmax=numpy.array([maxd,maxt,numpy.nan]).reshape(1,3)

  return nhmax,nsmax,dmax

def _stratio(covparam):
  '''
  Estimate the S/T ratio for dmax
  '''
  nm=len(covparam)
  sills= numpy.array([covparam[k][0] for k in xrange(nm)])  
  hrange = numpy.array([covparam[k][1][0] for k in xrange(nm)]) 
  idx0 = numpy.where([hrange[k] is not None for k in xrange(nm)])[0]
  idx=numpy.where(sills[idx0]==sills.max())[0]
  ratio=covparam[idx0[idx]][1][0]/covparam[idx0[idx]][2][0]
  return ratio

def _bme_posterior_prepare(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, options=None,
    nhmax=None, nsmax=None, dmax=None,
    general_knowledge='gaussian',
    #  specific_knowledge='unknown',  
    pdfk=None,pdfh=None,pdfs=None,hk_k=None,hk_h=None,hk_s=None,
    gui_args=None):

    '''
    check and configure arguments and 
        find neighbor ckhs index for bme posterior calculation

    ckhs_idx_list:  [ck_idx, ch_idx, cs_idx] represents
        these ck have the same neighbors ch and cs

    return (output_arguments, configured_arguments):
        a tuple contain arguments
    '''

    if covmat is None:
        if (covmodel is None) or (covparam is None):
            raise ValueError(
                'Covariance model and their associated parameters '\
                'should be specified if no covarinace matrix provided.')

    dk = ck.shape[1]
    nk = ck.shape[0]
    nh = ch.shape[0] if ch is not None else 0
    ns = cs.shape[0] if cs is not None else 0

    ck, ch, cs = _changetimeform(ck, ch, cs)
    nhmax, nsmax, dmax = _set_nh_ns(ck, ch, cs, nhmax, nsmax, dmax)
    if dmax.size == 3 and numpy.isnan(dmax[0][2]):
        dmax[0][2] = _stratio(covparam)
    stratio = dmax[0][2] if dk == 3 else 1.
      
    if options is None:
        options = BMEoptions()

    if gui_args:
        qpgd = gui_args[0]
    
    if general_knowledge == 'gaussian':
        order = get_standard_order(order)

        #aggregate ck for same hard data and soft data
        # chs = numpy.vstack(ch, cs)
        ck_norm = numpy.copy(ck)
        ck_norm[:, -1] = ck_norm[:, -1] * stratio
        if dk == 3:
            dmax_norm = (dmax[0][0]**2 + (dmax[0][1] * stratio)**2)**0.5
        else:
            dmax_norm = dmax[0][0]

        if isinstance(ch, numpy.ndarray) and nhmax != 0:
            ch_norm = numpy.copy(ch)
            ch_norm[:, -1] = ch_norm[:, -1] * stratio
            ch_tree = cKDTree(ch_norm)
        if isinstance(cs, numpy.ndarray) and nsmax != 0:
            cs_norm = numpy.copy(cs)
            cs_norm[:, -1] = cs_norm[:, -1] * stratio
            cs_tree = cKDTree(cs_norm)

        ckhs_idx_list = []
        if isinstance(ch, numpy.ndarray) and nhmax != 0: #has harddata
            ch_ck_dict =\
                neighbours_index_kd(ck_norm, ch_tree, nhmax, dmax_norm)
            for ch_idx, ck_idx in ch_ck_dict.iteritems():
                if isinstance(cs, numpy.ndarray) and nsmax != 0: #both hard and soft
                    picked_ck_norm = ck_norm[ck_idx, :]
                    cs_ck_dict =\
                        neighbours_index_kd(
                            picked_ck_norm, cs_tree, nsmax, dmax_norm
                            )
                    for cs_idx, ck2_idx in cs_ck_dict.iteritems():
                        ck_idx = numpy.array(ck_idx)
                        ckhs_idx_list.append(
                            [ck_idx[ck2_idx,], ch_idx, cs_idx]
                            )
                else: #only harddata
                    ckhs_idx_list.append([ck_idx, ch_idx, ()])
        elif isinstance(cs, numpy.ndarray) and nsmax != 0: #only softdata
            cs_ck_dict =\
                neighbours_index_kd(ck_norm, cs_tree, nsmax, dmax_norm)
            for cs_idx, ck_idx in cs_ck_dict.iteritems():
                ckhs_idx_list.append([ck_idx, (), cs_idx])
    else:
        raise ValueError("Now we can not consider non-gaussian GK." )

    configured_arguments =\
        (ck, ch, cs, zh, zs,
        covmodel, covparam, covmat,
        order, options,
        nhmax, nsmax, dmax,
        general_knowledge,
        pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
        gui_args)
    output_arguments = \
        (ckhs_idx_list,)
    return (output_arguments, configured_arguments)

def BMEPosteriorMoments(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, options=None,
    nhmax=None, nsmax=None, dmax=None,
    general_knowledge='gaussian',
    #  specific_knowledge='unknown',
    pdfk=None, pdfh=None, pdfs=None, hk_k=None, hk_h=None, hk_s=None,
    gui_args=None):
    '''
    zs: a sequence of soft data, e.g. (zs1, zs2, zs3, ..., zsn)
        each zsi(i=1~n) is a sequence of data arguments,
        first item should be softpdftype to determine
        the format of the remaining arguments, e.g.
        syntax: (softdata_type, *softdata_args)
            zs1 = (1, nl, limi, probdens)
            zs2 = (10, zm, zstd)
        if element is empty, put None
            e.g. (zs1, zs2, None, ..., zsn)

    covmat:
        covariance matrix, a numpy 2d array
        with shape (nk+nh+ns) by (nk+nh+ns)
    if covmat provided, covmodel and covparam are simply skipped.

    #SI = integrate (fg_s_given_kh * fs_s) dx_s
    #NC = integrate (fg_s_given_h * fs_s) dx_s
    #pdf_k = (fg_kh * SI) / (fg_h * NC)  # eq.1
    #exp_k = ... # eq.2
    #exp_kp = ... # eq.3
    #if general_knowledge == gaussian and specific_knowledge == unknown
    #exp_k = ... # eq.4
    #var_k = ... # eq.5

    gui_args: a tuple with gui arguments; gui_args[0] must offer
        wasCanceled(), value() and setValue()

    return
        gaussian general knowledge: nk by 3 array filled, group by
        group, with the output of _bme_posterior_moments; rows of
        estimation points that belong to no neighbour group stay NaN.
        False when the user cancels through gui_args[0].
        otherwise: a tuple of three 1-D arrays (first moment, variance,
        third central moment).

    raise numpy.linalg.LinAlgError as raised by _bme_posterior_moments
    '''
    (output_arguments, configured_arguments) = _bme_posterior_prepare(
        ck, ch, cs, zh, zs,
        covmodel, covparam, covmat,
        order, options,
        nhmax, nsmax, dmax,
        general_knowledge,
        pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
        gui_args)
    (ckhs_idx_list,) = output_arguments

    (ck, ch, cs, zh, zs,
        covmodel, covparam, covmat,
        order, options,
        nhmax, nsmax, dmax,
        general_knowledge,
        pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
        gui_args) = configured_arguments

    if general_knowledge == 'gaussian':
        # offsets of the hard and soft rows inside covmat
        n_est = ck.shape[0]
        n_hard = ch.shape[0] if isinstance(ch, numpy.ndarray) else 0

        zk = numpy.empty((n_est, 3))  # to 3rd moments
        zk.fill(numpy.nan)
        for ck_idx, ch_idx, cs_idx in ckhs_idx_list:
            ck_idx = numpy.array(ck_idx, dtype=int)
            ch_idx = numpy.array(ch_idx, dtype=int)
            cs_idx = numpy.array(cs_idx, dtype=int)
            if ck_idx.shape[0] == 0:
                continue  # nothing to estimate for this group

            picked_ch =\
                ch[ch_idx, :] if isinstance(ch, numpy.ndarray) else None
            picked_cs =\
                cs[cs_idx, :] if isinstance(cs, numpy.ndarray) else None
            picked_zh =\
                zh[ch_idx, :] if isinstance(zh, numpy.ndarray) else None
            picked_zs =\
                [zs[cs_idx_i] for cs_idx_i in cs_idx] if zs is not None else None

            # process at most 250 estimation points at a time
            split_count = int(numpy.ceil(ck_idx.shape[0]/250.))
            for ck_idx_piece in numpy.array_split(ck_idx, split_count):
                picked_ck = ck[ck_idx_piece, :]
                if covmat is not None:
                    covidx = numpy.hstack(
                        (ck_idx_piece, n_est+ch_idx, n_est+n_hard+cs_idx)
                        )
                    picked_covmat = covmat[numpy.ix_(covidx, covidx)]
                else:
                    picked_covmat = covmat

                zk[ck_idx_piece, :] = _bme_posterior_moments(
                    picked_ck, picked_ch, picked_cs,
                    picked_zh, picked_zs,
                    covmodel, covparam, picked_covmat,
                    order, options, general_knowledge,
                    pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
                    gui_args)

                if gui_args:
                    if gui_args[0].wasCanceled():  # cancel by user
                        return False
                    gui_args[0].setValue(
                        gui_args[0].value()+ck_idx_piece.size)
        return zk

    # NOTE(review): _bme_posterior_prepare currently raises ValueError
    # for a non-gaussian general knowledge, so the code below is not
    # reached at present.
    from cubature import cubature

    nk = len(pdfk)
    moments = numpy.empty((nk, 3))
    maxpts = options[2][0]
    rEps = options[3][0]

    for k in range(nk):
        print('BME MOMENTS:' + str(k+1) + '/' + str(nk))

        cklocal = ck[k:k+1, :]
        pdfk_local = [pdfk[k]]
        hk_k_local = [hk_k[k]]

        # keyword arguments keep every value on its own parameter
        pdf_k = BMEPosteriorPDF(
            cklocal, ch, cs, zh, zs,
            covmodel=covmodel, covparam=covparam, covmat=covmat,
            order=order, options=options,
            nhmax=nhmax, nsmax=nsmax, dmax=dmax,
            general_knowledge=general_knowledge,
            pdfk=pdfk_local, pdfh=pdfh, pdfs=pdfs,
            hk_k=hk_k_local, hk_h=hk_h, hk_s=hk_s)[0]

        # integrate over +/- 6 square roots of hk_k[k][1] around hk_k[k][0]
        zmin = hk_k[k][0]-6*numpy.sqrt(hk_k[k][1])
        zmax = hk_k[k][0]+6*numpy.sqrt(hk_k[k][1])

        def _raw_moment(power):
            # integral of x**power * pdf_k(x) over [zmin, zmax]
            integrand = lambda x_array: (
                x_array[:, 0]**power * pdf_k(x_array[:, 0], 0)[:, 0])
            value, _ = cubature(
                func=integrand, ndim=1, fdim=1, xmin=numpy.array([zmin]),
                xmax=numpy.array([zmax]), adaptive='h', maxEval=maxpts,
                abserr=0, relerr=rEps, vectorized=True)
            return numpy.ravel(value)[0]

        mon1 = _raw_moment(1)
        mon2 = _raw_moment(2)
        mon3 = _raw_moment(3)

        moments[k, 0] = mon1
        moments[k, 1] = mon2-mon1**2
        # third central moment: E[x^3] - 3*mean*variance - mean^3
        moments[k, 2] = mon3-3*mon1*moments[k, 1]-mon1**3

    return moments[:, 0], moments[:, 1], moments[:, 2]

def BMEPosteriorPDF(
    ck, ch=None, cs=None, zh=None, zs=None,
    covmodel=None, covparam=None, covmat=None,
    order=numpy.nan, options=None,
    nhmax=None, nsmax=None, dmax=None,
    general_knowledge='gaussian',
    #  specific_knowledge='unknown',
    pdfk=None, pdfh=None, pdfs=None, hk_k=None, hk_h=None, hk_s=None,
    gui_args=None):
    '''
    BME posterior pdf at every estimation point.

    The arguments are configured and the estimation points grouped by
    _bme_posterior_prepare; every group is then processed by
    _bme_posterior_pdf in pieces of at most 250 estimation points.

    covmat:
        covariance matrix, a numpy 2d array indexed here as
        [ck rows, then ch rows, then cs rows]
    gui_args: a tuple with gui arguments; gui_args[0] must offer
        wasCanceled(), value() and setValue()

    return
        nk by 1 object array filled with the output of
        _bme_posterior_pdf; entries of estimation points that belong to
        no neighbour group stay None.
        False when the user cancels through gui_args[0].
        None when general_knowledge is not 'gaussian'.
    '''
    (output_arguments, configured_arguments) = _bme_posterior_prepare(
        ck, ch, cs, zh, zs,
        covmodel, covparam, covmat,
        order, options,
        nhmax, nsmax, dmax,
        general_knowledge,
        pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
        gui_args)
    (ckhs_idx_list,) = output_arguments

    (ck, ch, cs, zh, zs,
        covmodel, covparam, covmat,
        order, options,
        nhmax, nsmax, dmax,
        general_knowledge,
        pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
        gui_args) = configured_arguments

    if general_knowledge != 'gaussian':
        return None

    # offsets of the hard and soft rows inside covmat
    nk = ck.shape[0]
    nh = ch.shape[0] if isinstance(ch, numpy.ndarray) else 0

    zk = numpy.empty((nk, 1), dtype=object)  # to 1 pdf function
    for ck_idx, ch_idx, cs_idx in ckhs_idx_list:
        ck_idx = numpy.array(ck_idx, dtype=int)
        ch_idx = numpy.array(ch_idx, dtype=int)
        cs_idx = numpy.array(cs_idx, dtype=int)
        if ck_idx.shape[0] == 0:
            continue  # nothing to estimate for this group

        picked_ch =\
            ch[ch_idx, :] if isinstance(ch, numpy.ndarray) else None
        picked_cs =\
            cs[cs_idx, :] if isinstance(cs, numpy.ndarray) else None
        picked_zh =\
            zh[ch_idx, :] if isinstance(zh, numpy.ndarray) else None
        picked_zs =\
            [zs[cs_idx_i] for cs_idx_i in cs_idx] if zs is not None else None

        # process at most 250 estimation points at a time
        split_count = int(numpy.ceil(ck_idx.shape[0]/250.))
        for ck_idx_piece in numpy.array_split(ck_idx, split_count):
            picked_ck = ck[ck_idx_piece, :]
            if covmat is not None:
                covidx = numpy.hstack(
                    (ck_idx_piece, nk+ch_idx, nk+nh+cs_idx)
                    )
                picked_covmat = covmat[numpy.ix_(covidx, covidx)]
            else:
                picked_covmat = covmat

            zk[ck_idx_piece, :] = _bme_posterior_pdf(
                picked_ck, picked_ch, picked_cs,
                picked_zh, picked_zs,
                covmodel, covparam, picked_covmat,
                order, options, general_knowledge,
                pdfk, pdfh, pdfs, hk_k, hk_h, hk_s,
                gui_args)

            if gui_args:
                if gui_args[0].wasCanceled():  # cancel by user
                    return False
                gui_args[0].setValue(
                    gui_args[0].value()+ck_idx_piece.size)
    return zk

def BMEPosteriorPDF_backup(
    ck, ch, cs, zh, zs=None,
    covmodel=None, covparam=None,
    order=numpy.nan, options=None,
    nhmax=None, nsmax=None, dmax=None,
    general_knowledge='gaussian',
    pdfk=None, pdfh=None, pdfs=None, hk_k=None, hk_h=None, hk_s=None):
    '''
    To obtain the BME posterior PDF with specified general and specific
    knowledge PDFs

    Input:
    ck    N by 3    2D array of the S/T coordinates of estimation points
    ch    N by 3    2D array of the S/T coordinates of hard data
    cs    N by 3    2D array of the S/T coordinates of soft data
    zh    N by 1    2D array to specify the observed values at ch
    zs    N by k    2D array to specify the uncertain observed values at cs
                    with the format as follows.
                    zs = (softpdftype, mean, variance)
                    zs = (softpdftype, nl, limi, probadens)
                    The more details can go to see the function
                    get_standard_soft_pdf_type function in starpy.bme.pystks_variable.py
    order integer   to specify the trend forms. NaN and 0 for zero and constant
                    or string means respectively
    options         BME options (look into BMEprobaMoments?); options[2][0]
                    and options[3][0] are read as the maximum number of
                    evaluations and the relative tolerance
    pdfk, hk_k      one entry per estimation point (non-gaussian case)
    pdfh, hk_h      one entry per hard datum (non-gaussian case)
    pdfs, hk_s      one entry per soft datum (non-gaussian case)

    Output:
    gaussian general knowledge: a function of xk returning the posterior
        pdf values
    otherwise: a list with one function mpdfk(xk, k) per estimation point

    Raises:
    ValueError          covmodel / covparam missing (gaussian case), or
                        pdfk / hk_k not matching the number of estimation
                        points (non-gaussian case)
    NotImplementedError gaussian case with soft data that are not of
                        type 10

    Note: the Gaussian part should be moved back into the proper places in order
    to generalize this function

    '''

    #SI = integrate (fg_s_given_kh * fs_s) dx_s
    #NC = integrate (fg_s_given_h * fs_s) dx_s
    #pdf_k = (fg_kh * SI) / (fg_h * NC)  # eq.1
    #exp_k = ... # eq.2
    #exp_kp = ... # eq.3
    #if general_knowledge == gaussian and specific_knowledge == unknown
    #exp_k = ... # eq.4
    #var_k = ... # eq.5

    # fg is gaussian, i.e., the Matlab version case

    ck, ch, cs = _changetimeform(ck, ch, cs)

    nhmax, nsmax, dmax = _set_nh_ns(ck, ch, cs, nhmax, nsmax, dmax)
    # numpy.isnan: an array element is never the numpy.nan object itself
    if dmax.size == 3 and numpy.isnan(dmax[0][2]) and covparam is not None:
        dmax[0][2] = _stratio(covparam)

    if options is None:
        options = BMEoptions()

    if general_knowledge == 'gaussian':
        if (covmodel is None) or (covparam is None):
            raise ValueError(
                'covariance model and their associated parameters '
                'should be specified')

        order = get_standard_order(order)
        nk = ck.shape[0]
        x_all_split = _get_x_all_split(nk, zh, zs)
        mean_all_split = _get_mean_all_split(x_all_split, order)
        cov_all_split = _get_cov_all_split(ck, ch, cs, covmodel, covparam)

        def fg_kh(xk):
            # joint pdf of (xk, hard data); xk may hold several values
            xk = numpy.array(xk)
            nlim = xk.size
            joint_pdf = _get_multivariate_normal_pdf(
                x_all_split, mean_all_split, cov_all_split, 'kh')
            x_hard = _get_x(x_all_split, 'h')
            if nlim == 1:
                return joint_pdf(numpy.vstack((xk, x_hard)).T)
            stacked = [xk.reshape((nlim, 1)),
                       numpy.ones((nlim, 1)).dot(x_hard.T)]
            return joint_pdf(numpy.hstack(stacked))

        fg_h = _get_multivariate_normal_pdf(
            x_all_split, mean_all_split, cov_all_split, 'h')(
                _get_x(x_all_split, 'h').T)
        if zs is None:
            return lambda xk: (fg_kh(xk)/fg_h)

        softpdftype = get_standard_soft_pdf_type(zs[0])
        if softpdftype != 10:  # specific_knowledge != 'gaussian'
            raise NotImplementedError(
                'Soft pdf type %r is not supported with gaussian '
                'general knowledge.' % (softpdftype,))
        SI = lambda xk: _get_int_fg_a_given_b_fs_s(
            [xk] + x_all_split[1:], mean_all_split,
            cov_all_split, zs, 's_kh', 's')[0]
        NC = _get_int_fg_a_given_b_fs_s(
            x_all_split, mean_all_split,
            cov_all_split, zs, 's_h', 's')[0]
        return lambda xk: (fg_kh(xk) * SI(xk) / fg_h / NC)

    # ---- non-gaussian general knowledge ----
    nk = ck.shape[0]
    if len(pdfk) != nk:
        raise ValueError(
            'The number of pdfk functions is not equal '
            'to the number of estimation locations')
    if len(hk_k) != nk:
        raise ValueError(
            'The number of hk_k is not equal to the number of '
            'estimation locations')

    mpdf = [None]*nk
    pdf_k = [None]*nk
    pdfs_kh = [None]*nk
    pdfs_h = [None]*nk
    zhlocal = [None]*nk
    zslocal = [None]*nk
    nklocal = 1

    maxpts = options[2][0]
    aEps = 0
    rEps = options[3][0]

    # Make the covariance function into correlation function.
    # Done once on a copy so the caller's covparam is left untouched.
    corrparam = copy.deepcopy(covparam)
    var = numpy.sum([model[0] for model in corrparam])
    for model in corrparam:
        model[0] = model[0]/var

    def mpdfk(xk, k):
        # posterior pdf at estimation point k for the values xk
        xk = numpy.array(xk)
        xk = xk.reshape((xk.size, 1))
        nlim = xk.size

        zh_n = numpy.ones((nlim, 1)).dot(zhlocal[k].T)
        xkzh = numpy.hstack([xk, zh_n])
        dens = pdf_k[k](xkzh)

        if zslocal[k] is not None:
            up, _, _ = pyAllMomentsNG(
                zslocal[k], xkzh, pdfs_kh[k], aEps, rEps, maxpts)
            bottom, _, _ = pyAllMomentsNG(
                zslocal[k], zhlocal[k].T, pdfs_h[k], aEps, rEps, maxpts)
            dens = dens*up/bottom

        return dens

    for k in range(nk):

        if nk > 1:
            print('BMEPDF:' + str(k+1) + '/' + str(nk))

        cklocal = ck[k:k+1, :]
        pdfklocal = [pdfk[k]]
        hk_k_local = [hk_k[k]]

        chlocal, zhlocal[k], dhlocal, sumnhlocal, idxhlocal = \
            neighbours(cklocal, ch, zh, nhmax, dmax)
        pdfhlocal = [pdfh[m] for m in idxhlocal]
        hk_h_local = [hk_h[m] for m in idxhlocal]

        if cs is not None:
            # neighbours() is only used for the indices of the soft data
            zsdummy = numpy.empty((cs.shape[0], 1))
            cslocal, zslocals, dslocal, sumnslocal, idxslocal = \
                neighbours(cklocal, cs, zsdummy, nsmax, dmax)
            pdfslocal = [pdfs[m] for m in idxslocal]
            hk_s_local = [hk_s[m] for m in idxslocal]
            idxslocal = idxslocal.flat
            if len(idxslocal) > 0:
                zslocal[k] = [zs[0], zs[1][idxslocal, :],
                              zs[2][idxslocal, :], zs[3][idxslocal, :]]
        else:
            # no soft data: zslocal[k] stays None
            cslocal = cs
            hk_s_local = hk_s
            pdfslocal = None

        x_all_split = _get_x_all_split(nklocal, zhlocal[k], zslocal[k])
        mean_all_split = _get_mean_all_split(x_all_split, order)

        cov_all_split = _get_cov_all_split(
            cklocal, chlocal, cslocal, covmodel, corrparam)

        if ch is not None:
            cov_kh = _get_sigma(cov_all_split, 'kh', 'kh', inv=False)
            cov_skh = _get_sigma(cov_all_split, 'skh', 'skh', inv=False)
            cov_sh = _get_sigma(cov_all_split, 'sh', 'sh', inv=False)

        pdf_k[k], _ = maxentcondpdf_gc(
            ppdf=pdfklocal+pdfhlocal, R=cov_kh,
            hk=hk_k_local+hk_h_local, k_num=len(pdfklocal))
        if zslocal[k] is not None:
            pdfs_kh[k], _ = maxentcondpdf_gc(
                ppdf=pdfslocal+pdfklocal+pdfhlocal, R=cov_skh,
                hk=hk_s_local+hk_k_local+hk_h_local,
                k_num=len(pdfslocal))
            pdfs_h[k], _ = maxentcondpdf_gc(
                ppdf=pdfslocal+pdfhlocal, R=cov_sh,
                hk=hk_s_local+hk_h_local,
                k_num=len(pdfslocal))

        # write up a pyallmoments here
        # to integrate the softdata into the equation (1) calculation
        # in the BME_OP_chapter

        mpdf[k] = mpdfk

    return mpdf

def BMEprobaGaussian(ck, ch, cs, zh, zs=None,
  covmodel=None, covparam=None, order=numpy.nan,
  nhmax=None, nsmax=None, dmax=None, gui_args=None):
    '''
    BME estimation when both general and specific knowledge are Gaussian.

    Handles the hard-only and the soft-data cases. The soft-data path is
    taken whenever ``zs`` is truthy, so ``zs=None`` and ``zs=[]`` both mean
    "no soft data".

    Parameters (as used by this function; details of the project helpers
    are not visible here -- TODO confirm against their definitions):
      ck, ch, cs    2-D arrays, one row per point; presumably the
                    coordinates of the estimation, hard-data and
                    soft-data points. They are first passed through
                    _changetimeform.
      zh            hard-data values (may be None on the soft-data path).
      zs            soft-data description forwarded to the helpers.
      covmodel,
      covparam      covariance description forwarded to
                    _get_cov_all_split.
      order         trend order, normalised with get_standard_order.
      nhmax, nsmax,
      dmax          neighbourhood limits; filled in by _set_nh_ns when
                    nhmax is None.
      gui_args      optional sequence whose first item exposes
                    setValue(); used as a progress indicator on the
                    time-sliced hard-only path.

    Returns -- NOTE: the shape of the result depends on the path taken:
      * soft data present: one array with three columns per estimation
        point: mean, variance and skewness (skewness is always zero).
      * hard data only, ck has fewer than 3 columns or fewer than 100
        rows: tuple (mean_k_given_h, sigma_k_given_h) exactly as returned
        by _get_mean_a_given_b / _get_sigma_a_given_b.
      * hard data only, otherwise: tuple (mean, variance) converted back
        to vector form by valstg2stv; the variance is the diagonal of the
        conditional covariance of each time slice.

    Raises:
      ValueError    on the time-sliced path when the coordinates rebuilt
                    from the space/time grid differ from ck, i.e. ck was
                    not given as a full grid.
    '''

    ck, ch, cs = _changetimeform(ck, ch, cs)
    # NOTE(review): defaults are only requested when nhmax is None; a
    # caller passing nhmax but leaving nsmax/dmax as None keeps those
    # None values -- confirm this is intended.
    if nhmax is None:
        nhmax, nsmax, dmax = _set_nh_ns(ck, ch, cs, nhmax, nsmax, dmax)

    order = get_standard_order(order)
    nk = ck.shape[0]

    if zs:
        # ---- soft data present (hard data optional) ----
        # should add GUI from here
        nh = zh.shape[0] if zh is not None else 0
        ns = cs.shape[0] if cs is not None else 0
        x_all_split = _get_x_all_split(nk, zh, zs)
        mean_all_split = _get_mean_all_split(x_all_split, order)
        cov_all_split = _get_cov_all_split(ck, ch, cs, covmodel, covparam)

        mean_k = _get_mean(mean_all_split, 'k')
        mean_hs = _get_mean(mean_all_split, 'hs')
        mean_s_given_h = _get_mean_a_given_b(
            x_all_split, mean_all_split, cov_all_split,
            sub_a='s', sub_b='h')

        NC, (sigma_t_prime, inv_sigma_s_given_h,
             inv_sigma_tilde_s, mean_tilde_s) = \
            _get_int_fg_a_given_b_fs_s(
                x_all_split, mean_all_split, cov_all_split, zs,
                sub_multi='s_h', sub_s='s')

        # The hat_* quantities are scaled by NC; they are divided by NC
        # again wherever they enter the final moments below.
        hat_x_s = NC * sigma_t_prime.dot(
            inv_sigma_s_given_h.dot(mean_s_given_h)
            + inv_sigma_tilde_s.dot(mean_tilde_s)
            )
        if nh > 0:
            hat_x_h = _get_x(x_all_split, 'h') * NC
            hat_x_hs = numpy.vstack((hat_x_h, hat_x_s))
        else:
            hat_x_hs = hat_x_s

        sigma_k_hs = _get_sigma(cov_all_split, sub_a='k', sub_b='hs')
        inv_sigma_hs_hs = _get_sigma(
            cov_all_split, sub_a='hs', sub_b='hs', inv=True)
        cond_k_hs = sigma_k_hs.dot(inv_sigma_hs_hs)

        BME_mean_k_given_hs_a = cond_k_hs.dot(hat_x_hs)
        BME_mean_k_given_hs_b = mean_k - cond_k_hs.dot(mean_hs)
        BME_mean_k_given_hs = \
            BME_mean_k_given_hs_a / NC + BME_mean_k_given_hs_b

        sigma_k_given_hs = _get_sigma_a_given_b(
            cov_all_split, sub_a='k', sub_b='hs')
        sigma_k_given_hs_diag = \
            sigma_k_given_hs.diagonal().reshape((-1, 1))

        mean_t = hat_x_hs / NC
        # Only the soft/soft block of aa is filled; the hard rows and
        # columns stay zero.
        aa = numpy.zeros(inv_sigma_hs_hs.shape)
        aa[nh:nh+ns, nh:nh+ns] = sigma_t_prime * NC
        bb = (mean_t - mean_hs).dot((mean_t - mean_hs).T) * NC
        tt = cond_k_hs.dot(aa + bb).dot(cond_k_hs.T)
        tt_diag = tt.diagonal().reshape((-1, 1))

        BME_var_k_given_hs = (
            sigma_k_given_hs_diag - BME_mean_k_given_hs**2 + mean_k**2
            - 2*mean_k * cond_k_hs.dot(mean_hs)
            + 2*mean_k * cond_k_hs.dot(hat_x_hs) / NC
            + tt_diag / NC
            )

        skewness = numpy.zeros(BME_mean_k_given_hs.shape)
        mvs = numpy.hstack(
            (BME_mean_k_given_hs, BME_var_k_given_hs, skewness)
            )
        return mvs

    # ---- hard data only ----
    dm = ck[0].size

    if dm < 3 or nk < 100:
        # Small request: condition all estimation points on all hard
        # data in a single pass.
        x_all_split = _get_x_all_split(nk, zh, zs)
        mean_all_split = _get_mean_all_split(x_all_split, order)
        cov_all_split = _get_cov_all_split(ck, ch, cs, covmodel, covparam)

        mean_k_given_h = _get_mean_a_given_b(
            x_all_split, mean_all_split, cov_all_split,
            sub_a='k', sub_b='h')
        sigma_k_given_h = _get_sigma_a_given_b(
            cov_all_split, sub_a='k', sub_b='h')
        return mean_k_given_h, sigma_k_given_h

    # Large request: work one time slice at a time. Only the grid
    # decomposition of ck is needed from valstv2stg, so the values handed
    # to it are placeholders and the gridded values are discarded.
    dummy = numpy.random.rand(nk, 1)
    _, cMS_k, tME_k, _ = valstv2stg(ck, dummy)
    nklocal = cMS_k.shape[0]

    # The neighbourhood of each slice is selected around the spatial
    # centroid of the grid at that slice's time.
    spatial_centre = numpy.mean(cMS_k, 0)

    # Collect one column per time slice and stack once after the loop.
    # The empty seed column keeps the result well-defined for zero slices.
    mean_columns = [numpy.empty((nklocal, 0))]
    var_columns = [numpy.empty((nklocal, 0))]

    # Here should add spatial split for large spatial data at a time
    # or the GUI will freeze
    for tt in range(tME_k.size):
        cklocal = numpy.hstack([spatial_centre, tME_k[tt]]).reshape(1, -1)
        cklocals = numpy.hstack(
            [cMS_k, numpy.ones((nklocal, 1)) * tME_k[tt]])
        chlocal, zhlocal, _, _, _ = \
            neighbours(cklocal, ch, zh, nhmax, dmax)

        x_all_split = _get_x_all_split(nklocal, zhlocal, zs)
        mean_all_split = _get_mean_all_split(x_all_split, order)
        cov_all_split = _get_cov_all_split(
            cklocals, chlocal, cs, covmodel, covparam)

        mean_columns.append(
            _get_mean_a_given_b(
                x_all_split, mean_all_split, cov_all_split,
                sub_a='k', sub_b='h'))
        var_columns.append(
            numpy.diag(
                _get_sigma_a_given_b(cov_all_split, sub_a='k', sub_b='h')
                ).reshape(nklocal, 1))

        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(str(tt+1) + '/' + str(tME_k.size))
        if gui_args:
            gui_args[0].setValue(nklocal*(tt+1))

    mean_k_given_h = numpy.hstack(mean_columns)
    sigma_k_given_h = numpy.hstack(var_columns)

    ck2, mean_k_given_h_v = valstg2stv(mean_k_given_h, cMS_k, tME_k)
    ck2, sigma_k_given_h_v = valstg2stv(sigma_k_given_h, cMS_k, tME_k)

    # ck != ck2 will occur when ck does not come from a grid input;
    # needs to be fixed ASAP. array_equal also rejects a shape mismatch
    # instead of broadcasting.
    if not numpy.array_equal(ck2, ck):
        print('warning: ck and ck2 are not the same')
        raise ValueError('Now ck only can input with grid.')

    return mean_k_given_h_v, sigma_k_given_h_v
