Lesson 8 of 15

Jensen-Shannon Divergence


The Jensen-Shannon divergence (JSD) fixes the asymmetry of KL divergence by averaging the KL divergence from each distribution to their mixture:

$$\text{JSD}(P \| Q) = \frac{1}{2} D_{KL}(P \| M) + \frac{1}{2} D_{KL}(Q \| M)$$

where $M = \frac{P + Q}{2}$ is the mixture distribution (element-wise average).

Key Properties

  • Symmetric: $\text{JSD}(P \| Q) = \text{JSD}(Q \| P)$
  • Bounded: $0 \leq \text{JSD} \leq 1$ when using $\log_2$
  • Square root is a metric: $\sqrt{\text{JSD}(P \| Q)}$ satisfies the triangle inequality
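These properties are easy to check numerically. A minimal sketch (the distributions `p` and `q` below are arbitrary illustrative choices, not from the lesson):

```python
import math

def js_divergence(p, q):
    # Mixture M = (P + Q) / 2, then average the two base-2 KL divergences
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    kl_pm = sum(pi * math.log2(pi / mi) for pi, mi in zip(p, m) if pi > 0)
    kl_qm = sum(qi * math.log2(qi / mi) for qi, mi in zip(q, m) if qi > 0)
    return 0.5 * kl_pm + 0.5 * kl_qm

p = [0.7, 0.2, 0.1]
q = [0.1, 0.3, 0.6]

# Symmetric: swapping the arguments gives the same value
assert abs(js_divergence(p, q) - js_divergence(q, p)) < 1e-12
# Bounded: 0 <= JSD <= 1 with log base 2
assert 0.0 <= js_divergence(p, q) <= 1.0
```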

Interpretation

| JSD | Meaning |
| --- | --- |
| 0 | $P = Q$ (identical distributions) |
| 1 | $P$ and $Q$ have disjoint supports (completely different) |
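Both endpoints can be verified directly. A self-contained sketch, with `js_divergence` defined the same way as in the example further down:

```python
import math

def js_divergence(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    kl_pm = sum(pi * math.log2(pi / mi) for pi, mi in zip(p, m) if pi > 0)
    kl_qm = sum(qi * math.log2(qi / mi) for qi, mi in zip(q, m) if qi > 0)
    return 0.5 * kl_pm + 0.5 * kl_qm

# Identical distributions: every log term is log2(1) = 0, so JSD = 0
print(js_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0
# Disjoint supports: each distribution is twice its share of M, so JSD = 1
print(js_divergence([1.0, 0.0], [0.0, 1.0]))  # 1.0
```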

Jensen-Shannon Distance

The square root of JSD is a proper metric:

$$\text{JSD-distance}(P, Q) = \sqrt{\text{JSD}(P \| Q)}$$
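Because the square root turns JSD into a metric, the triangle inequality holds for any triple of distributions. A quick sketch checking it on three illustrative distributions (the specific values are arbitrary, not from the lesson):

```python
import math

def js_divergence(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    kl_pm = sum(pi * math.log2(pi / mi) for pi, mi in zip(p, m) if pi > 0)
    kl_qm = sum(qi * math.log2(qi / mi) for qi, mi in zip(q, m) if qi > 0)
    return 0.5 * kl_pm + 0.5 * kl_qm

def js_distance(p, q):
    # Square root of JSD: a proper metric
    return math.sqrt(js_divergence(p, q))

p = [0.9, 0.1]
q = [0.5, 0.5]
r = [0.1, 0.9]

# Triangle inequality: d(P, R) <= d(P, Q) + d(Q, R)
assert js_distance(p, r) <= js_distance(p, q) + js_distance(q, r)
```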

import math

def js_divergence(p, q):
    # Mixture distribution M = (P + Q) / 2
    m = [(p[i] + q[i]) / 2 for i in range(len(p))]
    # KL(P || M) and KL(Q || M) in base 2; skip zero-probability terms
    kl_pm = sum(p[i] * math.log2(p[i] / m[i]) for i in range(len(p)) if p[i] > 0)
    kl_qm = sum(q[i] * math.log2(q[i] / m[i]) for i in range(len(q)) if q[i] > 0)
    return 0.5 * kl_pm + 0.5 * kl_qm

p = [0.5, 0.5]
q = [0.25, 0.75]
print(round(js_divergence(p, q), 4))  # 0.0488

Your Task

Implement:

  • js_divergence(p, q) — $0.5 \cdot D_{KL}(P \| M) + 0.5 \cdot D_{KL}(Q \| M)$ where $M = (P + Q) / 2$
  • js_distance(p, q) — $\sqrt{\text{JSD}(P \| Q)}$