KL divergence - chunhualiao/public-docs GitHub Wiki

What is KL Divergence?

Kullback-Leibler (KL) Divergence is a measure of how one probability distribution differs from another. It quantifies the amount of information lost when using an approximate distribution ( Q(x) ) instead of the true distribution ( P(x) ).

KL divergence is not symmetric, meaning:

D_{KL}(P || Q) \neq D_{KL}(Q || P)

It is always non-negative and equals zero only when ( P(x) = Q(x) ) for all ( x ).


Mathematical Definition

For discrete probability distributions ( P(x) ) and ( Q(x) ), KL divergence is defined as:

D_{KL}(P || Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}

For continuous distributions, the sum becomes an integral:

D_{KL}(P || Q) = \int_{-\infty}^{\infty} P(x) \log \frac{P(x)}{Q(x)} \,dx
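
For the common special case of two univariate Gaussians, this integral has a well-known closed form: ( D_{KL}(\mathcal{N}(\mu_1, \sigma_1^2) \,||\, \mathcal{N}(\mu_2, \sigma_2^2)) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2} ). The sketch below (illustrative only; the parameter values are arbitrary) checks the closed form against a direct numerical integration of the definition:

```python
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Closed-form KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ), in nats."""
    return np.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2) - 0.5

mu1, sigma1, mu2, sigma2 = 0.0, 1.0, 1.0, 2.0

# Numerically integrate P(x) * log(P(x) / Q(x)) over a wide interval
integrand = lambda x: norm.pdf(x, mu1, sigma1) * (norm.logpdf(x, mu1, sigma1) - norm.logpdf(x, mu2, sigma2))
numeric, _ = quad(integrand, -20, 20)

print(f"closed form: {gaussian_kl(mu1, sigma1, mu2, sigma2):.4f}")  # ~0.4431
print(f"numeric:     {numeric:.4f}")                                # ~0.4431
```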

Breaking Down Each Term

  • $P(x)$: True (reference) distribution.
  • $Q(x)$: Approximate distribution.
  • $\frac{P(x)}{Q(x)}$: How much ( Q(x) ) underestimates or overestimates ( P(x) ).
  • $\log \frac{P(x)}{Q(x)}$: Measures the divergence at each ( x ), with larger values meaning greater differences.
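
To make the role of each term concrete, here is a minimal sketch (the distributions are made up for illustration) that computes the per-element contributions ( P(x) \log \frac{P(x)}{Q(x)} ) before summing them:

```python
import numpy as np

P = np.array([0.4, 0.3, 0.2, 0.1])  # true distribution
Q = np.array([0.3, 0.3, 0.3, 0.1])  # approximate distribution

ratio = P / Q              # how much Q under- or over-estimates P at each x
log_ratio = np.log(ratio)  # divergence at each x, in nats
contrib = P * log_ratio    # weighted by how likely x is under P

print("per-element contributions:", np.round(contrib, 4))  # [ 0.1151  0.     -0.0811  0.    ]
print("KL(P || Q) =", round(float(contrib.sum()), 4))       # 0.034
```

Note that individual contributions can be negative (wherever ( Q(x) > P(x) )); only the total is guaranteed to be non-negative.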

Explanation of KL Divergence Terms

KL divergence is defined in a way that quantifies the difference between two probability distributions, ensuring meaningful comparisons in various applications, including reinforcement learning. The table below explains why each term is defined as it is.


Table: Explanation of KL Divergence Terms

| Term | Mathematical Representation | Why is it Defined This Way? |
| --- | --- | --- |
| KL Divergence Formula | $D_{KL}(P \parallel Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$ | Quantifies the information lost when $Q$ is used in place of the true distribution $P$. |
| True Distribution | $P(x)$ | Represents the actual distribution of data (ground truth). |
| Approximate Distribution | $Q(x)$ | Represents an estimated or learned distribution. |
| Probability Ratio | $\frac{P(x)}{Q(x)}$ | Measures how much $Q(x)$ underestimates or overestimates $P(x)$. |
| Log of Probability Ratio | $\log \frac{P(x)}{Q(x)}$ | Converts ratio differences into a logarithmic scale for better interpretability. |
| Weighting by True Probability | $P(x) \log \frac{P(x)}{Q(x)}$ | Ensures that more probable events contribute more to the divergence. |
| Summation Over All Events | $\sum_{x} P(x) \log \frac{P(x)}{Q(x)}$ | Computes an aggregate measure of how different $P$ and $Q$ are. |
| Expectation Formulation | $D_{KL}(P \parallel Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]$ | Interprets the divergence as the expected log-ratio under the true distribution $P$. |
| Non-Negativity | $D_{KL}(P \parallel Q) \geq 0$ | Guarantees the divergence is never negative and is zero only when $P = Q$. |
| Asymmetry | $D_{KL}(P \parallel Q) \neq D_{KL}(Q \parallel P)$ | Reflects that the information lost using $Q$ in place of $P$ differs from the reverse direction. |
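
The expectation row in the table also suggests a practical estimator: sample from ( P ) and average the log-ratio. A minimal sketch (same toy distributions as above; the sample size is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)

P = np.array([0.4, 0.3, 0.2, 0.1])
Q = np.array([0.3, 0.3, 0.3, 0.1])

# Draw samples x ~ P, then average log( P(x) / Q(x) ) over the samples
samples = rng.choice(len(P), size=100_000, p=P)
mc_estimate = np.mean(np.log(P[samples] / Q[samples]))

print(f"Monte Carlo estimate: {mc_estimate:.4f}")  # close to the exact value of ~0.0340
```

This sampled form is how KL terms are typically estimated in practice when the full distributions are too large to enumerate (for example, over all possible token sequences of a language model).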

Key Insights

  1. Why Use a Log Function?

    • The log function transforms multiplicative differences (ratios) into additive differences, making divergence easier to interpret.
    • It also connects divergence to information theory: ( \log \frac{P(x)}{Q(x)} ) is the extra information (in nats or bits) needed to encode ( x ) using ( Q ) instead of ( P ), so even modest mismatches register as a measurable cost, helping detect model misalignment early.
  2. Why Weight by ( P(x) )?

    • Events that are more common in ( P ) contribute more to the divergence measure.
    • Rare events don’t dominate the divergence calculation, preventing outliers from misleading optimization.
  3. Why is KL Divergence Non-Negative?

    • The log ratio is zero when ( P(x) = Q(x) ), making the entire sum zero.
    • More generally, because ( \log ) is concave, Jensen's inequality gives ( \mathbb{E}_{P}\left[\log \frac{Q(x)}{P(x)}\right] \leq \log \mathbb{E}_{P}\left[\frac{Q(x)}{P(x)}\right] = \log 1 = 0 ), so ( D_{KL}(P || Q) \geq 0 ). Individual terms can be negative where ( Q(x) > P(x) ), but the weighted sum cannot be.
  4. Why is KL Divergence Asymmetric?

    • ( D_{KL}(P || Q) ) measures how much information is lost when using ( Q ) instead of ( P ).
    • ( D_{KL}(Q || P) ) would measure how much information is lost when reversing the roles.
    • In RL, using one-way divergence ensures a stable learning signal by keeping updates conservative.
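
A quick numerical check of points 3 and 4, using the same toy distributions as earlier (purely illustrative):

```python
import numpy as np

P = np.array([0.4, 0.3, 0.2, 0.1])
Q = np.array([0.3, 0.3, 0.3, 0.1])

def kl(p, q):
    """Discrete KL divergence sum_x p(x) * log(p(x) / q(x)), in nats."""
    return float(np.sum(p * np.log(p / q)))

print(f"D_KL(P || Q) = {kl(P, Q):.4f}")  # ~0.0340
print(f"D_KL(Q || P) = {kl(Q, P):.4f}")  # ~0.0353: a different value, and both are >= 0
```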

Final Takeaway

  • Each term in KL divergence is defined to ensure it correctly quantifies the difference between distributions.
  • The log function, probability weighting, and summation work together to provide a meaningful measure of divergence.
  • KL divergence is used in reinforcement learning (PPO, GRPO) to regularize policy updates and prevent drastic changes.

🚀 This formulation ensures that learning remains stable, interpretable, and effective!

Python Code to Compute KL Divergence

Here’s how you can compute KL divergence using Python:

```python
import numpy as np
from scipy.special import rel_entr  # relative entropy (element-wise KL terms)

# Define two probability distributions (must sum to 1)
P = np.array([0.4, 0.3, 0.2, 0.1])  # True distribution
Q = np.array([0.3, 0.3, 0.3, 0.1])  # Approximate distribution

# Compute KL divergence: rel_entr computes P * log(P / Q) element-wise (natural log, i.e. nats)
kl_div = np.sum(rel_entr(P, Q))
print(f"KL Divergence: {kl_div:.4f}")
```

Expected Output

```
KL Divergence: 0.0340
```

This value, measured in nats (rel_entr uses the natural logarithm), is the amount of information lost when using ( Q(x) ) instead of ( P(x) ).


Why is KL Divergence Used in Reinforcement Learning?

KL divergence is commonly used in policy optimization algorithms like PPO and GRPO for regularization.

1. Prevents Drastic Policy Updates

  • If the new policy ( \pi_{\theta} ) diverges too much from the old policy ( \pi_{\theta_{\text{old}}} ), training can become unstable.
  • KL divergence keeps updates small and controlled.

2. Controls Exploration vs. Exploitation

  • Large KL divergence → Policy is changing too much (unstable).
  • Small KL divergence → Policy is barely improving (too conservative).
  • This helps balance learning new things (exploration) vs. using learned strategies (exploitation).
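
One common way to act on this signal (for example, in the adaptive-penalty variant of PPO) is to adjust the KL coefficient whenever the measured divergence drifts outside a target band. The sketch below is a minimal illustration; the target value, the 1.5x band, and the factor-of-2 adjustment are assumed hyperparameters, not fixed prescriptions:

```python
def adapt_kl_coefficient(beta: float, measured_kl: float, target_kl: float = 0.01) -> float:
    """Raise the KL penalty when the policy moves too fast, relax it when it barely moves."""
    if measured_kl > 1.5 * target_kl:    # policy changing too much -> penalize harder
        beta *= 2.0
    elif measured_kl < target_kl / 1.5:  # policy barely moving -> allow larger steps
        beta /= 2.0
    return beta

# Example: a large measured KL doubles the penalty coefficient
print(adapt_kl_coefficient(beta=0.1, measured_kl=0.05))  # 0.2
```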

3. Used in PPO and GRPO

  • PPO (Proximal Policy Optimization):

    J(\theta) = \mathbb{E} \left[ \min \left( r(\theta) A, \text{clip}(r(\theta), 1 - \epsilon, 1 + \epsilon) A \right) - \beta D_{KL}(\pi_{\theta} || \pi_{\text{old}}) \right]
    
    • The KL penalty ( D_{KL}(\pi_{\theta} || \pi_{\text{old}}) ) discourages the new policy from deviating too far.
  • GRPO (Group Relative Policy Optimization):

    J_{GRPO}(\theta) = \mathbb{E} \left[ \sum_{i=1}^{G} \left( \frac{\pi_{\theta}(o_i | q)}{\pi_{\theta_{\text{old}}}(o_i | q)} A_i - \beta D_{KL}(\pi_{\theta} || \pi_{\text{ref}}) \right) \right]
    
    • Here, KL divergence ensures the policy doesn’t shift too far from a reference policy.
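
As a concrete illustration of the penalty term shared by both objectives, the sketch below compares the KL penalty incurred by a small policy update versus a drastic one, for a single state with four actions. All probabilities and the coefficient ( \beta ) are made-up values, not taken from an actual PPO/GRPO implementation:

```python
import numpy as np

def kl(p, q):
    """Discrete KL divergence in nats."""
    return float(np.sum(p * np.log(p / q)))

# Action probabilities for one state (illustrative numbers)
pi_ref = np.array([0.25, 0.25, 0.25, 0.25])  # reference / old policy
pi_new = np.array([0.30, 0.28, 0.22, 0.20])  # small, conservative update
pi_far = np.array([0.90, 0.05, 0.03, 0.02])  # drastic update

beta = 0.5  # KL penalty coefficient

for name, pi in [("small update  ", pi_new), ("drastic update", pi_far)]:
    print(f"{name}: KL = {kl(pi, pi_ref):.4f}, penalty = {beta * kl(pi, pi_ref):.4f}")

# The drastic update pays a far larger penalty, so maximizing
# J = surrogate_reward - beta * KL steers the optimizer toward gradual changes.
```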

Intuition with a Simple Example

Imagine you're training an AI chess bot. If its policy suddenly changes drastically (e.g., from playing defensively to recklessly sacrificing all pieces), KL divergence penalizes such drastic updates, ensuring a more gradual, stable learning process.


Final Thoughts

  • KL divergence measures the difference between two probability distributions.
  • It prevents unstable training by limiting policy updates.
  • It is widely used in RL to ensure controlled learning and avoid over-exploration.
  • GRPO and PPO both rely on KL divergence to regularize updates and improve stability.

🚀 KL divergence is a key tool for improving reinforcement learning algorithms!