API 1.0.0. Bellman Equations - Reinforcement-Learning-TU-Vienna/dice_rl_TU_Vienna GitHub Wiki
`bellman_equations` (Utility Functions)
This module provides solvers for the forward and backward Bellman equations in tabular MDPs, in both analytical and approximate variants. These solvers are used by the analytical solvers and model-based algorithms throughout the project.
🧮 Mathematical Formulation
The quantities involved in the Bellman equations are:
- `d0`: Initial (state-action) distribution $d^\pi_0$
- `dD`: Dataset (state-action) distribution $d^D$
- `P`: (State-action) transition matrix $P^\pi$
- `r`: Expected reward function $r$
These are related by the following equations:
- The state-action value function $Q^\pi$ solves the forward Bellman equations for $0 < \gamma < 1$:
$$ Q = r + \gamma P^\pi Q. $$
- The stationary distribution $d^\pi$ solves the backward Bellman equations for $0 < \gamma \leq 1$:
$$ d = (1 - \gamma) d_0 + \gamma P^\pi_\top d. $$
- The stationary distribution correction $w_{\pi / D}$ solves the modified backward Bellman equations for $0 < \gamma \leq 1$:
$$ D^D w = (1 - \gamma) d^\pi_0 + \gamma P^\pi_\top D^D w. $$
Note that the Bellman equations are linear systems in the discounted case and turn into an eigenvalue problem for the eigenvalue $1$ in the undiscounted case ($\gamma = 1$). In this implementation, we choose the eigenvector whose eigenvalue is closest to $1$.
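For example, in the undiscounted case the backward Bellman equations reduce to $d = P^\pi_\top d$, i.e. an eigenvector of $P^\pi_\top$ for the eigenvalue $1$. A minimal numpy sketch of this eigenvalue approach, with a toy transition matrix (an assumption, not taken from the library):

```python
import numpy as np

# Hypothetical 2-state transition matrix P (rows sum to 1).
P = np.array([[0.9, 0.1],
              [0.5, 0.5]])

# Undiscounted backward Bellman equations: d = P^T d, i.e. d is an
# eigenvector of P^T for the eigenvalue 1.
eigvals, eigvecs = np.linalg.eig(P.T)

# Pick the eigenvector whose eigenvalue is closest to 1.
i = np.argmin(np.abs(eigvals - 1))
d = np.real(eigvecs[:, i])
d /= d.sum()  # normalize to a probability distribution
```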
If the exact environment dynamics are not available, we can still infer the quantities above from a given dataset $D$, leading to:
- `d0_bar`: Empirical initial visitation counts $\bar d^\pi_0$
- `dD_bar`: Empirical dataset visitation counts $\bar d^D$
- `P_bar`: Empirical transition counts $\bar P^\pi$
- `r_bar`: Empirical reward totals $\bar r$
From this, we can derive the estimates:
- `d0_hat`: Empirical initial distribution $\hat d^\pi_0$
- `dD_hat`: Empirical dataset distribution $\hat d^D$
- `P_hat`: Empirical transition matrix $\hat P^\pi$
- `r_hat`: Empirical expected reward function $\hat r$
These estimates are computed as:

```python
n = sum(dD_bar)
d0_hat = d0_bar / n
dD_hat = dD_bar / n
P_hat = (P_bar.T / dD_bar).T
r_hat = r_bar / dD_bar
```
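For concreteness, here is a toy run of these formulas on hypothetical counts for two state-action pairs (all values are made up for illustration):

```python
import numpy as np

# Hypothetical counts accumulated from a dataset D (assumptions).
dD_bar = np.array([3.0, 1.0])    # visitation counts per state-action
d0_bar = np.array([4.0, 0.0])    # initial visitation counts
P_bar = np.array([[2.0, 1.0],
                  [1.0, 0.0]])   # transition counts (from-row, to-column)
r_bar = np.array([6.0, 2.0])     # reward totals per state-action

n = dD_bar.sum()                 # total number of samples
d0_hat = d0_bar / n              # empirical initial distribution
dD_hat = dD_bar / n              # empirical dataset distribution
P_hat = (P_bar.T / dD_bar).T     # row-normalized empirical transition matrix
r_hat = r_bar / dD_bar           # empirical expected reward
```

Note that `(P_bar.T / dD_bar).T` divides each row of the counts by the visitation count of its source state-action pair, so every row of `P_hat` sums to one.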
For further details, refer to the Bellman Equations wiki page.
🚀 Solve
`def solve_forward_bellman_equations(dD, r, P, gamma, projected=False):`

Calls `solve_standard_forward_bellman_equations` to solve the forward Bellman equations. If `projected=True`, the expected reward `r` and the transition matrix `P` are first projected onto the support of the dataset distribution `dD`.

Args:
- `dD` (np.ndarray): Dataset distribution $d^D$.
- `r` (np.ndarray): Expected reward vector $r$.
- `P` (np.ndarray): Transition matrix $P^\pi$.
- `gamma` (float): Discount factor $\gamma$.
- `projected` (bool): Whether to project onto the support of $d^D$ before solving.

Returns:
- `Q` (np.ndarray): State-action value function $Q^\pi$.
- `info` (dict): Solver metadata.
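In the discounted case, the forward Bellman equations reduce to the linear system $(I - \gamma P^\pi) Q = r$, so the core computation behind this solver can be sketched as follows (toy `P`, `r`, and `gamma`; this is a sketch, not the library's implementation):

```python
import numpy as np

# Forward Bellman equations: Q = r + gamma * P @ Q,
# equivalently (I - gamma * P) Q = r.
gamma = 0.9
P = np.array([[0.8, 0.2],
              [0.3, 0.7]])  # toy transition matrix (assumption)
r = np.array([1.0, 0.0])    # toy expected reward (assumption)

Q = np.linalg.solve(np.eye(2) - gamma * P, r)

# The solution satisfies the fixed-point equation up to numerical precision.
assert np.allclose(Q, r + gamma * P @ Q)
```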
`def solve_forward_bellman_equations_approximate(dD_bar, r_bar, P_bar, gamma, projected=False):`

Calls `solve_standard_forward_bellman_equations_approximate` to solve an approximate version of the forward Bellman equations. If `projected=True`, the empirical reward totals `r_bar` and the empirical transition counts `P_bar` are first projected onto the support of the empirical dataset visitation counts `dD_bar`.

Args:
- `dD_bar` (np.ndarray): Empirical dataset visitation counts $\bar d^D$.
- `r_bar` (np.ndarray): Empirical reward totals $\bar r$.
- `P_bar` (np.ndarray): Empirical transition counts $\bar P^\pi$.
- `gamma` (float): Discount factor $\gamma$.
- `projected` (bool): Whether to project onto the support of the dataset distribution $d^D$ before solving.

Returns:
- `Q_hat` (np.ndarray): Estimated state-action value function $\hat Q^\pi$.
- `info` (dict): Solver metadata.
`def solve_backward_bellman_equations(d0, dD, P, gamma, modified=False, projected=False):`

Calls `solve_*_backward_bellman_equations`, where `*` is `standard` or `modified`, to solve the (modified) backward Bellman equations. If `projected=True`, the initial distribution `d0` and the transition matrix `P` are first projected onto the support of the dataset distribution `dD`.

Args:
- `d0` (np.ndarray): Initial distribution $d^\pi_0$.
- `dD` (np.ndarray): Dataset distribution $d^D$.
- `P` (np.ndarray): Transition matrix $P^\pi$.
- `gamma` (float): Discount factor $\gamma$.
- `modified` (bool): Whether to use the modified or the standard backward Bellman equations.
- `projected` (bool): Whether to project onto the support of the dataset distribution $d^D$ before solving.

Returns:
- `w` (np.ndarray): Stationary distribution correction $w_{\pi / D}$.
- `info` (dict): Solver metadata.
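In the discounted case, the modified backward Bellman equations can be solved by first computing $x = D^D w$ from the linear system $(I - \gamma P^\pi_\top) x = (1 - \gamma) d^\pi_0$ and then dividing by $d^D$. A minimal sketch with toy inputs (assumptions; not the library's implementation):

```python
import numpy as np

gamma = 0.9
d0 = np.array([1.0, 0.0])   # toy initial distribution (assumption)
dD = np.array([0.6, 0.4])   # toy dataset distribution (assumption)
P = np.array([[0.8, 0.2],
              [0.3, 0.7]])  # toy transition matrix (assumption)

# Solve for x = D^D w, then recover the correction w = x / dD.
x = np.linalg.solve(np.eye(2) - gamma * P.T, (1 - gamma) * d0)
w = x / dD  # stationary distribution correction w_{pi/D}
```

Since the rows of `P` sum to one, `x` sums to one, which yields the normalization property $\langle w, d^D \rangle = 1$ checked by `test_sdc`.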
`def solve_backward_bellman_equations_approximate(d0_bar, dD_bar, P_bar, n, gamma, modified=False, projected=False):`

Calls `solve_*_backward_bellman_equations_approximate`, where `*` is `standard` or `modified`, to solve an approximate version of the (modified) backward Bellman equations. If `projected=True`, the empirical initial visitation counts `d0_bar` and the empirical transition counts `P_bar` are first projected onto the support of the empirical dataset visitation counts `dD_bar`.

Args:
- `d0_bar` (np.ndarray): Empirical initial visitation counts $\bar d^\pi_0$.
- `dD_bar` (np.ndarray): Empirical dataset visitation counts $\bar d^D$.
- `P_bar` (np.ndarray): Empirical transition counts $\bar P^\pi$.
- `n` (int): Number of samples $n$.
- `gamma` (float): Discount factor $\gamma$.
- `modified` (bool): Whether to use the modified or the standard backward Bellman equations.
- `projected` (bool): Whether to project onto the support of the dataset distribution $d^D$ before solving.

Returns:
- `w_hat` (np.ndarray): Estimated stationary distribution correction $\hat w_{\pi / D}$.
- `info` (dict): Solver metadata.
⚙️ Utility
`def solve_standard_forward_bellman_equations(r, P, gamma):`

Solves the forward Bellman equations.

`def solve_standard_forward_bellman_equations_approximate(dD_bar, r_bar, P_bar, gamma):`

Solves an approximate version of the forward Bellman equations.

`def solve_standard_backward_bellman_equations(d0, P, gamma):`

Solves the backward Bellman equations.

`def solve_modified_backward_bellman_equations(d0, dD, P, gamma):`

Solves the modified backward Bellman equations.

`def solve_standard_backward_bellman_equations_approximate(d0_bar, dD_bar, P_bar, n, gamma):`

Solves an approximate version of the backward Bellman equations.

`def solve_modified_backward_bellman_equations_approximate(d0_bar, dD_bar, P_bar, n, gamma):`

Solves an approximate version of the modified backward Bellman equations.
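To illustrate how the approximate variants relate to the exact ones: forming the empirical estimates from the counts and then solving the exact forward equations with them gives the same kind of result the approximate solver targets. A sketch with toy counts (assumptions; not the library's implementation):

```python
import numpy as np

gamma = 0.9
dD_bar = np.array([3.0, 1.0])   # toy visitation counts (assumption)
r_bar = np.array([6.0, 2.0])    # toy reward totals (assumption)
P_bar = np.array([[2.0, 1.0],
                  [1.0, 0.0]])  # toy transition counts (assumption)

# Form the empirical estimates, then solve (I - gamma * P_hat) Q_hat = r_hat.
P_hat = (P_bar.T / dD_bar).T
r_hat = r_bar / dD_bar
Q_hat = np.linalg.solve(np.eye(2) - gamma * P_hat, r_hat)
```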
🔬 Diagnostics
`def test_avf(gamma, Q, P, r):`

Prints the average of the forward Bellman equation residual:

$$ r + \gamma P^\pi Q - Q. $$

`def test_sd(gamma, d, d0, P):`

Prints the average of the backward Bellman equation residual:

$$ (1 - \gamma) d^\pi_0 + \gamma P^\pi_\top d - d. $$

Checks properties of the stationary distribution:

$$ d \geq 0, \quad \sum d = 1. $$

`def test_sdc(gamma, w, d0, dD, P):`

Prints the average of the modified backward Bellman equation residual:

$$ (1 - \gamma) d^\pi_0 + \gamma P^\pi_\top D^D w - D^D w. $$

Checks properties of the stationary distribution correction:

$$ w \geq 0, \quad \langle w, d^D \rangle = 1. $$
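A `test_sd`-style diagnostic can be sketched as follows, with toy inputs (assumptions): it computes the average backward residual and checks the distribution properties.

```python
import numpy as np

gamma = 0.9
P = np.array([[0.8, 0.2],
              [0.3, 0.7]])   # toy transition matrix (assumption)
d0 = np.array([1.0, 0.0])    # toy initial distribution (assumption)

# Exact solution of the discounted backward Bellman equations.
d = np.linalg.solve(np.eye(2) - gamma * P.T, (1 - gamma) * d0)

# Average backward Bellman residual; ~0 for an exact solution.
residual = (1 - gamma) * d0 + gamma * P.T @ d - d
print("average residual:", residual.mean())

# Stationary distribution properties: d >= 0 and sum(d) = 1.
assert np.all(d >= 0) and np.isclose(d.sum(), 1.0)
```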