API 1.3.1.2. TabularDice - Reinforcement-Learning-TU-Vienna/dice_rl_TU_Vienna GitHub Wiki

TabularDice

TabularDice estimates the policy value $\rho^\pi$ by first approximating the stationary distribution correction $w_{\pi/D}: S \times A \to \mathbb{R}_{\geq 0}$ in the tabular setting. It is a model-based estimator that leverages auxiliary statistics derived from logged data. It inherits from ModelBasedTabularOffPE and overrides the necessary base methods.

🧮 Mathematical Formulation

To obtain the stationary distribution correction, TabularDice approximates the (modified) backward Bellman equations and solves them directly.
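As a rough illustration of what "solving the backward Bellman equations directly" can look like in the tabular case, the sketch below assumes the standard discounted form $d^\pi = (1 - \gamma) d_0^\pi + \gamma P_\pi^\top d^\pi$ and a fully known model. The arrays `P_pi`, `d0` and `dD` are made-up toy values, not library objects, and the modified variant used by TabularDice may differ in detail.

```python
import numpy as np

gamma = 0.9

# Hypothetical toy model over 3 flattened state-action pairs.
# P_pi[i, j]: probability of moving from pair i to pair j under the target policy.
P_pi = np.array([
    [0.1, 0.6, 0.3],
    [0.4, 0.4, 0.2],
    [0.5, 0.0, 0.5],
])
d0 = np.array([0.5, 0.3, 0.2])  # initial state-action distribution under the target policy
dD = np.array([0.2, 0.5, 0.3])  # dataset distribution (full support in this toy case)

# Backward Bellman equations: d_pi = (1 - gamma) * d0 + gamma * P_pi^T d_pi,
# rearranged into the linear system (I - gamma * P_pi^T) d_pi = (1 - gamma) * d0.
A = np.eye(3) - gamma * P_pi.T
d_pi = np.linalg.solve(A, (1 - gamma) * d0)

# The stationary distribution correction is the ratio of the two distributions.
w = d_pi / dD
```

Under these assumptions `d_pi` sums to one, so the correction averages to one under the dataset distribution, i.e. `np.dot(w, dD)` equals 1.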

Dual to TabularDice, TabularVafe uses the forward Bellman equations and the state-action value function $Q^\pi$ to estimate the policy value $\rho^\pi$.
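The duality can be checked numerically on a toy problem. The sketch below is not taken from either estimator: it assumes a known model, the normalized policy value $\rho^\pi = (1 - \gamma) \, \mathbb{E}_{d_0^\pi}[Q^\pi]$, and made-up arrays `P_pi`, `d0` and `r`. It computes $\rho^\pi$ once through the forward equations and once through the backward equations.

```python
import numpy as np

gamma = 0.9

# Hypothetical toy model over 3 flattened state-action pairs.
P_pi = np.array([
    [0.1, 0.6, 0.3],
    [0.4, 0.4, 0.2],
    [0.5, 0.0, 0.5],
])
d0 = np.array([0.5, 0.3, 0.2])  # initial state-action distribution under the target policy
r = np.array([1.0, 0.0, 2.0])   # expected reward per state-action pair

# Forward Bellman equations: Q = r + gamma * P_pi Q
Q = np.linalg.solve(np.eye(3) - gamma * P_pi, r)
rho_forward = (1 - gamma) * d0 @ Q

# Backward Bellman equations: d_pi = (1 - gamma) * d0 + gamma * P_pi^T d_pi
d_pi = np.linalg.solve(np.eye(3) - gamma * P_pi.T, (1 - gamma) * d0)
rho_backward = d_pi @ r
```

Both routes invert the same matrix, once as is and once transposed, which is why `rho_forward` and `rho_backward` agree up to floating-point error.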

For further details, refer to the Bellman Equations wiki page.

🚀 Solve

def solve_sdc(self, gamma, **kwargs):

Solves for the stationary distribution correction $w_{\pi / D}$, by using the approximate (modified) backward Bellman equations.

Args:

  • gamma (float): Discount factor $\gamma$.
  • projected (bool, kwargs): Whether to solve the version of the Bellman equations projected onto the support of the dataset distribution $d^D$.
  • modified (bool, kwargs): Whether to use the modified or the standard backward Bellman equations.

Returns:

  • w_hat (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
  • info (dict): Additional metadata from the solver.
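To make the "approximate" part concrete, here is a minimal sketch of how backward Bellman equations for $w_{\pi / D}$ could be assembled from transition counts and solved with NumPy. It is not the library's implementation. The index arrays `sa`, `sa_next` and `sa_init` are hypothetical, the standard (unmodified) equations are used, and every state-action pair is assumed to be visited, so projecting onto the support of $d^D$ would change nothing here.

```python
import numpy as np

gamma = 0.9
n_sa = 3  # number of flattened state-action pairs (hypothetical)

# Hypothetical logged transitions: visited pair, and the next pair whose
# action is drawn from the target policy.
sa = np.array([0, 0, 1, 1, 1, 2, 2, 0])
sa_next = np.array([1, 2, 0, 1, 2, 0, 2, 1])
# Hypothetical initial pairs (initial state, action from the target policy).
sa_init = np.array([0, 0, 1, 2])

n = len(sa)
dD_hat = np.bincount(sa, minlength=n_sa) / n                  # empirical dataset distribution
d0_hat = np.bincount(sa_init, minlength=n_sa) / len(sa_init)  # empirical initial distribution

# N[i, j]: number of logged transitions from pair i to pair j.
N = np.zeros((n_sa, n_sa))
np.add.at(N, (sa, sa_next), 1)

# Approximate backward equations written in terms of w:
#   dD_hat * w = (1 - gamma) * d0_hat + gamma * (N^T / n) @ w
A = np.diag(dD_hat) - gamma * N.T / n
w_hat = np.linalg.solve(A, (1 - gamma) * d0_hat)
```

With full support the system has a unique solution and `np.dot(w_hat, dD_hat)` equals 1. If some pair were never visited, the corresponding column of `A` would be zero, which is one reason a version restricted to the support can be useful.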

def solve_pv(self, w_hat, weighted):

Estimates the policy value $\rho^\pi$ from the approximated stationary distribution correction $\hat w_{\pi / D}$.

Args:

  • gamma (float): Discount factor $\gamma$.
  • w_hat (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
  • weighted (bool): Whether to normalize by the sum of weights np.dot(w, dD) instead of the number of samples n.

Returns:

  • rho_hat (float): Estimated policy value $\hat \rho^\pi$.
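The effect of `weighted` can be illustrated with a few lines of NumPy. The sketch below is only an assumption about the computation based on the argument description above: it treats `dD` as visit counts, `r_bar` as the mean logged reward per state-action pair, and uses made-up values throughout.

```python
import numpy as np

# Hypothetical statistics over 3 flattened state-action pairs.
dD = np.array([3.0, 3.0, 2.0])     # visit counts in the dataset
r_bar = np.array([1.0, 0.0, 2.0])  # mean logged reward per pair
w_hat = np.array([1.2, 0.8, 1.1])  # some approximated correction

n = dD.sum()  # number of samples
numerator = np.dot(w_hat * r_bar, dD)

rho_unweighted = numerator / n                # normalize by the number of samples
rho_weighted = numerator / np.dot(w_hat, dD)  # normalize by the sum of weights
```

The two estimates coincide whenever the weights sum to `n` exactly. Otherwise the weighted variant rescales the estimate so that the weights act as a proper distribution over the samples.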

def solve(self, gamma, **kwargs):

Runs the full evaluation pipeline: estimates $w_{\pi / D}$ via solve_sdc and computes $\rho^\pi$ via solve_pv.

🧪 Example

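# Assumes `dataset` (the logged experience) has already been loaded
# and that TabularDice has been imported from the package.
dice = TabularDice(dataset, n_obs=5, n_act=3)  # tabular problem with 5 observations and 3 actions
gamma = 0.99  # discount factor

# projected and modified are options of solve_sdc; weighted is an option of solve_pv.
rho_hat, info = dice.solve(
    gamma,
    projected=True,
    modified=True,
    weighted=True
)
w_hat = info["w_hat"]  # approximated stationary distribution correction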