从高斯pm,pv到高斯qm,q的KullbackLeibler散度

2024-06-16 11:14:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图计算从高斯1到高斯2的Kullback-Leibler散度 我有高斯函数的平均值和标准差 我从http://www.cs.cmu.edu/~chanwook/MySoftware/rm1_Spk-by-Spk_MLLR/rm1_PNCC_MLLR_1/rm1/python/sphinx/divergence.py尝试了这段代码

def gau_kl(pm, pv, qm, qv):
    """
    Kullback-Leibler divergence from Gaussian pm,pv to Gaussian qm,qv.
    Also computes KL divergence from a single Gaussian pm,pv to a set
    of Gaussians qm,qv.
    Diagonal covariances are assumed.  Divergence is expressed in nats.
    """
    if (len(qm.shape) == 2):
        axis = 1
    else:
        axis = 0
    # Determinants of diagonal covariances pv, qv
    dpv = pv.prod()
    dqv = qv.prod(axis)
    # Inverse of diagonal covariance qv
    iqv = 1./qv
    # Difference between means pm, qm
    diff = qm - pm
    return (0.5 *
            (numpy.log(dqv / dpv)            # log |\Sigma_q| / |\Sigma_p|
             + (iqv * pv).sum(axis)          # + tr(\Sigma_q^{-1} * \Sigma_p)
             + (diff * iqv * diff).sum(axis) # + (\mu_q-\mu_p)^T\Sigma_q^{-1}(\mu_q-\mu_p)
             - len(pm)))                     # - N

我使用平均值和标准偏差作为输入,但是代码的最后一行(len(pm))会导致错误,因为平均值是一个数字,我不理解这里的len函数。在

注意。两组(即高斯数)不相等,所以我不能用scipy.stats.熵在


Tags: oflendiffgaussiansigma平均值pvaxis
2条回答

如果你还感兴趣。。。在

该函数期望多元高斯协方差矩阵的对角项,而不是您提到的标准差。如果您的输入是一元高斯函数,那么pv和{}都是对应高斯函数方差的长度为1的向量。在

另外,len(pm)对应于均值向量的维数。在多元正态分布的截面here中,它确实是k。对于一元高斯函数,k为1,对于二元高斯函数,k为2,依此类推。在

以下函数计算任意两个多元正态分布之间的KL散度(协方差矩阵不需要是对角的)(其中numpy作为np导入)

def kl_mvn(m0, S0, m1, S1):
    """
    Kullback-Liebler divergence from Gaussian pm,pv to Gaussian qm,qv.
    Also computes KL divergence from a single Gaussian pm,pv to a set
    of Gaussians qm,qv.
    Diagonal covariances are assumed.  Divergence is expressed in nats.

    - accepts stacks of means, but only one S0 and S1

    From wikipedia
    KL( (m0, S0) || (m1, S1))
         = .5 * ( tr(S1^{-1} S0) + log |S1|/|S0| + 
                  (m1 - m0)^T S1^{-1} (m1 - m0) - N )
    """
    # store inv diag covariance of S1 and diff between means
    N = m0.shape[0]
    iS1 = np.linalg.inv(S1)
    diff = m1 - m0

    # kl is made of three terms
    tr_term   = np.trace(iS1 @ S0)
    det_term  = np.log(np.linalg.det(S1)/np.linalg.det(S0)) #np.sum(np.log(S1)) - np.sum(np.log(S0))
    quad_term = diff.T @ np.linalg.inv(S1) @ diff #np.sum( (diff*diff) * iS1, axis=1)
    #print(tr_term,det_term,quad_term)
    return .5 * (tr_term + det_term + quad_term - N) 

相关问题 更多 >