KL散度和JS散度

【按:组会上师姐用来比较数据的概率分布模型时,用到了这两个指标。会后,用DeepSeek查询了一下这两个指标的定义和计算方法,浅浅记录一下,以备未来回顾。】


定义

KL散度(Kullback-Leibler Divergence)和JS散度(Jensen-Shannon Divergence)是衡量两个概率分布差异的常用工具,广泛应用于机器学习、信息论和统计学等领域。以下是它们的详细说明:


1. KL散度(Kullback-Leibler Divergence)

定义

KL散度衡量分布 $P$ 相对于分布 $Q$ 的相对熵(信息损失),表示用 $Q$ 近似 $P$ 时丢失的信息量。

  • 非对称性:$D_{KL}(P | Q) \neq D_{KL}(Q | P)$。
  • 非负性:$D_{KL}(P | Q) \geq 0$,当且仅当 $P = Q$ 时为零。

计算公式

  • 离散分布
    $$
    D_{KL}(P | Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}
    $$
  • 连续分布
    $$
    D_{KL}(P | Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx
    $$

应用场景

  • 模型训练:在变分自编码器(VAE)中,用于约束隐变量分布接近先验分布。
  • 信息编码:优化编码方案时最小化信息损失。
  • 强化学习:策略梯度方法中限制策略更新的幅度。

注意事项

  • 当 $Q(x)=0$ 且 $P(x)>0$ 时,KL散度无定义(需平滑处理)。
  • 不对称性可能导致不同结果(如 $P$ 是真实分布时更常用 $D_{KL}(P | Q)$)。

2. JS散度(Jensen-Shannon Divergence)

定义

JS散度是KL散度的对称化版本,基于 $P$ 和 $Q$ 的平均分布 $M = \frac{P + Q}{2}$ 计算。

  • 对称性:$D_{JS}(P | Q) = D_{JS}(Q | P)$。
  • 有界性:$0 \leq D_{JS} \leq \log 2$,便于比较不同分布间的差异。

计算公式

$$
D_{JS}(P | Q) = \frac{1}{2} D_{KL}(P | M) + \frac{1}{2} D_{KL}(Q | M), \quad M = \frac{P + Q}{2}
$$

应用场景

  • 生成对抗网络(GAN):早期用于衡量生成分布与真实分布的差异(后因梯度问题被Wasserstein距离取代)。
  • 聚类分析:评估聚类结果与真实标签的相似性。
  • 生物学:比较基因表达分布的差异。

优点

  • 对称性使其更适合需要双向比较的场景。
  • 避免KL散度的无定义问题(因 $M$ 融合了 $P$ 和 $Q$)。

3. 核心区别

特性 KL散度 JS散度
对称性 非对称 对称
取值范围 $[0, +\infty)$ $[0, \log 2]$
鲁棒性 对 $Q$ 的零值敏感 对零值鲁棒
计算复杂度 直接计算 需计算平均分布 $M$

4. 实例说明

假设两个离散分布:

  • $P = [0.2, 0.5, 0.3]$
  • $Q = [0.1, 0.6, 0.3]$

KL散度计算
$$
D_{KL}(P | Q) = 0.2 \log \frac{0.2}{0.1} + 0.5 \log \frac{0.5}{0.6} + 0.3 \log \frac{0.3}{0.3} \approx 0.036
$$

JS散度计算

  1. 计算平均分布 $M = \frac{P+Q}{2} = [0.15, 0.55, 0.3]$
  2. 计算两项KL散度后取平均:
    $$
    D_{JS}(P | Q) \approx \frac{1}{2} (0.025 + 0.047) \approx 0.036
    $$

5. 总结

  • KL散度:适用于需要单向衡量信息损失的场景(如模型优化)。
  • JS散度:适合对称性比较,但可能因梯度消失问题在GAN中受限。
  • 选择建议:优先考虑对称性和鲁棒性时选JS;需要明确方向性时选KL。

在python和R中计算KL散度和JS散度

在Python和R中计算KL散度(Kullback-Leibler Divergence)和JS散度(Jensen-Shannon Divergence)可以通过内置函数或第三方库实现。以下是具体方法和代码示例:


1. Python实现

KL散度

方法1:使用scipy.stats.entropy
1
2
3
4
5
6
7
8
9
10
import numpy as np
from scipy.stats import entropy

# 定义两个离散概率分布P和Q(需归一化)
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])

# 计算KL散度(注意顺序:P对Q的KL散度)
kl_divergence = entropy(P, Q, base=2) # base=2表示以2为底(结果单位为比特)
print("KL Divergence (P || Q):", kl_divergence)

注意

  • 输入必须是概率分布(和为1),否则需手动归一化。
  • 若Q中存在0而P中对应位置非0,结果会为inf(需平滑处理)。
方法2:手动计算
1
2
3
4
5
6
7
8
def kl_divergence(P, Q):
# 避免log(0)的情况(添加微小值平滑)
epsilon = 1e-10
P = P + epsilon
Q = Q + epsilon
return np.sum(P * np.log2(P / Q))

print("KL Divergence (Manual):", kl_divergence(P, Q))

JS散度

方法1:基于KL散度计算
1
2
3
4
5
6
7
8
def js_divergence(P, Q):
# 计算平均分布M
M = 0.5 * (P + Q)
# 计算JS散度
js = 0.5 * entropy(P, M, base=2) + 0.5 * entropy(Q, M, base=2)
return js

print("JS Divergence:", js_divergence(P, Q))
方法2:使用scipy.spatial.distance.jensenshannon
1
2
3
4
5
6
from scipy.spatial.distance import jensenshannon

# 直接计算JS散度(返回值为距离,需平方后才是散度)
js_distance = jensenshannon(P, Q, base=2)
js_divergence = js_distance ** 2
print("JS Divergence (scipy):", js_divergence)

注意

  • jensenshannon返回的是JS距离(即JS散度的平方根),需平方得到散度值。

2. R实现

KL散度

方法1:使用philentropy
1
2
3
4
5
6
7
8
9
install.packages("philentropy")  # 首次使用需安装
library(philentropy)

P <- c(0.5, 0.3, 0.2)
Q <- c(0.4, 0.4, 0.2)

# 计算KL散度(注意参数unit="log2"表示以2为底)
kl_divergence <- KL(rbind(P, Q), unit = "log2")
print(paste("KL Divergence (P || Q):", kl_divergence))
方法2:手动计算
1
2
3
4
5
6
7
8
kl_divergence <- function(P, Q) {
epsilon <- 1e-10
P <- P + epsilon
Q <- Q + epsilon
sum(P * log2(P / Q))
}

print(paste("KL Divergence (Manual):", kl_divergence(P, Q)))

JS散度

方法1:基于KL散度计算
1
2
3
4
5
6
js_divergence <- function(P, Q) {
M <- 0.5 * (P + Q)
0.5 * KL(rbind(P, M), unit = "log2") + 0.5 * KL(rbind(Q, M), unit = "log2")
}

print(paste("JS Divergence:", js_divergence(P, Q)))
方法2:使用philentropy包的JSD
1
2
3
# JSD函数默认以e为底,需调整
js_divergence <- JSD(rbind(P, Q), unit = "log2")
print(paste("JS Divergence (philentropy):", js_divergence))

3. 注意事项

  1. 输入归一化:确保输入的概率分布和为1(可通过P = P / np.sum(P)P <- P / sum(P)处理)。
  2. 零值处理:添加微小值(如1e-10)避免除零或对数零错误。
  3. 对称性:JS散度对称,KL散度不对称(注意顺序)。
  4. 连续分布:若需处理连续分布,需通过数值积分(如scipy.integrate.quad或R的integrate)。

4. 完整示例(Python)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon

# 定义分布
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])

# KL散度
kl_scipy = entropy(P, Q, base=2)
print("KL (Scipy):", kl_scipy)

# JS散度
js_scipy = jensenshannon(P, Q, base=2) ** 2
print("JS (Scipy):", js_scipy)

# 手动验证
def kl_div(P, Q):
eps = 1e-10
P, Q = P + eps, Q + eps
return np.sum(P * np.log2(P / Q))

def js_div(P, Q):
M = 0.5 * (P + Q)
return 0.5 * kl_div(P, M) + 0.5 * kl_div(Q, M)

print("KL (Manual):", kl_div(P, Q))
print("JS (Manual):", js_div(P, Q))

5. 常见问题

  • 结果不一致:不同库可能对底数(自然对数lnlog2)或零值处理方式不同,需统一标准。
  • 高维数据:对于高维分布(如图像生成),建议使用Wasserstein距离或MMD(如torchmetrics库中的实现)。