两个多维高斯分布之间的KL散度推导
在深度学习中,我们通常对模型进行抽样并计算与真实样本之间的损失,来估计模型分布与真实分布之间的差异。并且损失可以定义得很简单,比如二范数即可。但是对于已知参数的两个确定分布之间的差异,我们就要通过推导的方式来计算了。
下面对已知均值与协方差矩阵的两个多维高斯分布之间的KL散度进行推导。当然,因为便于分布之间的逼近,Wasserstein distance可能是衡量两个分布之间差异的更好方式,但这个有点难,以后再记录。
首先定义两个$n$维高斯分布如下:
$\begin{aligned} &p(x) = \frac{1}{(2\pi)^{0.5n}|\Sigma|^{0.5}}\exp\left(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right)\\ &q(x) = \frac{1}{(2\pi)^{0.5n}|L|^{0.5}}\exp\left(-\frac{1}{2}(x-m)^T L^{-1}(x-m)\right)\\ \end{aligned}$
需要计算的是:
$\begin{aligned} \text{KL}(p||q) = \text{E}_p\left(\log\frac{p(x)}{q(x)}\right) \end{aligned}$
为了方便说明,下面分步进行推导。首先:
$\begin{aligned} \frac{p(x)}{q(x)} &= \frac {\frac{1}{(2\pi)^{0.5n}|\Sigma|^{0.5}}\exp\left(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right)} {\frac{1}{(2\pi)^{0.5n}|L|^{0.5}}\exp\left(-\frac{1}{2}(x-m)^T L^{-1}(x-m)\right)}\\ &=\left(\frac{|L|}{|\Sigma|}\right)^{0.5}\exp\left(\frac{1}{2}(x-m)^T L^{-1}(x-m) -\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right) \end{aligned}$
然后加上对数:
$\begin{aligned} \log\frac{p(x)}{q(x)} &= \frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \frac{1}{2}(x-m)^T L^{-1}(x-m) - \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu) \end{aligned}$
再加上期望:
$\begin{aligned} \text{E}_p\log\frac{p(x)}{q(x)} &=\frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \text{E}_p\left[\frac{1}{2}(x-m)^T L^{-1}(x-m) - \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right]\\ &=\frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \text{E}_p\text{Tr}\left[\frac{1}{2}(x-m)^T L^{-1}(x-m) - \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right]\\ \end{aligned}$
第二步是因为结果为标量,可以转换为计算迹的形式。接着由迹的平移不变性得:
$\begin{align} &\frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \text{E}_p\text{Tr} \left[ \frac{1}{2}L^{-1}(x-m)(x-m)^T - \frac{1}{2}\Sigma^{-1}(x-\mu)(x-\mu)^T \right]\\ = &\frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \frac{1}{2}\text{E}_p\text{Tr} \left(L^{-1}(x-m)(x-m)^T\right) - \frac{1}{2}\text{E}_p\text{Tr} \left(\Sigma^{-1}(x-\mu)(x-\mu)^T\right) \\ = &\frac{1}{2}\log\frac{|L|}{|\Sigma|}+ \frac{1}{2}\text{E}_p\text{Tr} \left(L^{-1}(x-m)(x-m)^T\right) - \frac{n}{2} \end{align}$
其中最后一项是因为,首先期望与迹可以调换位置,然后$(x-\mu)(x-\mu)^T$在分布$p$下的期望就是对应的协方差矩阵$\Sigma$,于是得到一个$n$维单位阵,再计算单位阵的迹为$n$。
接下来,把中间项提出来推导,得:
$\begin{align} &\frac{1}{2}\text{E}_p\text{Tr} \left(L^{-1}(x-m)(x-m)^T\right)\\ =&\frac{1}{2}\text{Tr}\left(L^{-1}\text{E}_p \left(xx^T-xm^T-mx^T+mm^T \right) \right) \\ =&\frac{1}{2}\text{Tr}\left(L^{-1} \left(\Sigma +\mu\mu^T-2\mu m^T+mm^T \right) \right) \end{align}$
其中$\text{E}_p(xx^T) = \Sigma + \mu\mu^T$推导如下:
$\begin{aligned} \Sigma &= \text{E}_p\left[(x-\mu)(x-\mu)^T\right]\\ &= \text{E}_p\left(xx^T-x\mu^T-\mu x^T+\mu\mu^T\right)\\ &= \text{E}_p\left(xx^T\right)-2\text{E}_p\left(x\mu^T\right)+\mu\mu^T \\ &= \text{E}_p\left(xx^T\right)-\mu\mu^T \\ \end{aligned}$
接着推导$(6)$式:
$\begin{aligned} &\frac{1}{2}\text{Tr}\left(L^{-1} \left(\Sigma +\mu\mu^T-2\mu m^T+mm^T \right) \right) \\ = &\frac{1}{2}\text{Tr}\left(L^{-1}\Sigma +L^{-1} (\mu-m)(\mu-m)^T \right) \\ = &\frac{1}{2}\text{Tr}\left(L^{-1}\Sigma\right)+ \frac{1}{2}(\mu-m)L^{-1}(\mu-m)^T \\ \end{aligned}$
最后代回$(3)$式,得到最终结果:
$\begin{aligned} \text{E}_p\log\frac{p(x)}{q(x)} =&\frac{1}{2}\left\{ \log\frac{|L|}{|\Sigma|}+ \text{Tr}\left(L^{-1}\Sigma\right)+ (\mu-m)L^{-1}(\mu-m)^T - n \right\} \end{aligned}$
- 多变量高斯分布之间的KL散度(KL Divergence)
- 两个高斯分布乘积的理论推导
- 贝叶斯滤波(四)两个高斯分布函数相乘、卷积推导
- 计算出任意两个日期之间相隔的天数
- iOS开发 两个视图之间值传递的常用方法<五>
- easyUI中两个layout之间元素的拖拽(draggable)或节点被覆盖的原因
- 有两个地方,用到了javabean对象和属性字符串值之间的转换
- Update 两个表之间数据更新
- VFP_获得:月天数.月初日期.月末日期及两个日期之间天数.月数.年数(十豆三)
- 两个Zimbra邮件系统之间的LDAP认证
- (6)多个线程 之间共享数据的方式探讨(设计4个线程,其中两个线程每次对j增加1,另外两个线程对j每次减少1 )
- Oracle 两个日期之间的时间间隔
- ------如何用语句在两个数据库之间复制存储过程----
- 去除两个 inline-block 之间的间距
- 【PHP原生】计算两个已知经纬度之间的距离
- 如何计算两个空间向量之间的转角
- * Java 两个 Java bean 之间的赋值
- 实现两个Mysql数据库之间的主从同步
- 两个页面之间传参数(包括单例模式,sugue,tableView,代理模式)
- 两个Activity之间传递List<T>数据