API 1.3.1.1. TabularVafe - Reinforcement-Learning-TU-Vienna/dice_rl_TU_Vienna GitHub Wiki

TabularVafe

TabularVafe estimates the policy value $\rho^\pi$ by first approximating the state-action value function $Q^\pi: S \times A \to \mathbb{R}$ in the tabular setting. It is a model-based estimator built on auxiliary statistics derived from logged data. It inherits from ModelBasedTabularOffPE and overrides all of the necessary base methods.

🧮 Mathematical Formulation

To obtain the state-action value function, TabularVafe approximates the forward Bellman equations from the logged data and solves them directly.
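
For orientation, the forward Bellman equations are written out below in a standard tabular notation. The transition kernel $T$, the reward $r$, the initial-state distribution $d_0$, the policy-weighted transition matrix $P^\pi$ with entries $P^\pi[(s,a),(s',a')] = T(s' \mid s,a)\,\pi(a' \mid s')$, and the $(1-\gamma)$ normalization of $\rho^\pi$ are notation assumed here. The Bellman Equations wiki page is the authority on the conventions this library actually uses.

```math
\begin{aligned}
Q^\pi(s,a) &= r(s,a) + \gamma \sum_{s' \in S} T(s' \mid s,a) \sum_{a' \in A} \pi(a' \mid s')\, Q^\pi(s',a'), \\
Q^\pi &= r + \gamma P^\pi Q^\pi \quad\Longrightarrow\quad Q^\pi = (I - \gamma P^\pi)^{-1} r, \\
\rho^\pi &= (1 - \gamma) \sum_{s \in S} \sum_{a \in A} d_0(s)\, \pi(a \mid s)\, Q^\pi(s,a).
\end{aligned}
```

Replacing $T$, $r$ and $d_0$ by estimates computed from the logged data turns the first two lines into approximate equations whose solution is $\hat Q^\pi$, and the last line into the estimate $\hat \rho^\pi$. Because $\gamma < 1$ and $P^\pi$ is row-stochastic, $I - \gamma P^\pi$ is invertible.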

TabularDice is the dual counterpart of TabularVafe: it uses the backward Bellman equations and the stationary distribution correction $w_{\pi / D}$ to estimate the policy value $\rho^\pi$.
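
The duality can be checked numerically on a small example. The sketch below builds a random MDP, solves the forward equations for $Q^\pi$ and the backward equations for the discounted state-action distribution $d^\pi$, forms the correction $w_{\pi/D} = d^\pi / d^D$, and confirms that both routes give the same $\rho^\pi$. All quantities are made up for illustration, including the full-support dataset distribution `d_D`. The code uses exact model quantities rather than estimates from logged data. It also assumes the normalized convention $\rho^\pi = (1-\gamma)\,\mathbb{E}_{s_0 \sim d_0,\, a_0 \sim \pi}[Q^\pi(s_0,a_0)]$, which is an assumption about this library. The helper names are hypothetical and are not part of the library's API.

```python
import numpy as np


def policy_transition_matrix(T, pi):
    # P[(s, a), (s', a')] = T(s' | s, a) * pi(a' | s'), flat index s * n_act + a
    n_obs, n_act, _ = T.shape
    P = T[:, :, :, None] * pi[None, None, :, :]
    return P.reshape(n_obs * n_act, n_obs * n_act)


def forward_and_backward_values(T, r, pi, d0, d_D, gamma):
    P = policy_transition_matrix(T, pi)
    n = P.shape[0]
    r_vec = r.reshape(-1)
    d0_pi = (d0[:, None] * pi).reshape(-1)  # initial state-action distribution
    # Forward (TabularVafe-style): Q = r + gamma * P Q
    Q = np.linalg.solve(np.eye(n) - gamma * P, r_vec)
    rho_forward = (1.0 - gamma) * (d0_pi @ Q)
    # Backward (TabularDice-style): d_pi = (1 - gamma) * d0_pi + gamma * P^T d_pi
    d_pi = np.linalg.solve(np.eye(n) - gamma * P.T, (1.0 - gamma) * d0_pi)
    # Stationary distribution correction w_{pi/D} = d_pi / d_D
    d_D_vec = d_D.reshape(-1)
    w = d_pi / d_D_vec
    rho_backward = d_D_vec @ (w * r_vec)
    return rho_forward, rho_backward, d_pi


rng = np.random.default_rng(0)
n_obs, n_act, gamma = 4, 2, 0.9
T = rng.random((n_obs, n_act, n_obs))
T /= T.sum(axis=2, keepdims=True)      # normalize to T(s' | s, a)
r = rng.random((n_obs, n_act))
pi = rng.random((n_obs, n_act))
pi /= pi.sum(axis=1, keepdims=True)    # normalize to pi(a | s)
d0 = np.full(n_obs, 1.0 / n_obs)
d_D = np.full((n_obs, n_act), 1.0 / (n_obs * n_act))  # made-up dataset distribution

rho_f, rho_b, d_pi = forward_and_backward_values(T, r, pi, d0, d_D, gamma)
print(rho_f, rho_b)  # equal up to floating-point error
```

The two values agree because $r^\top d^\pi = (1-\gamma)\, r^\top (I - \gamma {P^\pi}^\top)^{-1} d_0^\pi = (1-\gamma)\, (d_0^\pi)^\top (I - \gamma P^\pi)^{-1} r$. The forward route needs $Q^\pi$, while the backward route needs $d^\pi$ (or $w_{\pi/D}$).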

For further details, refer to the Bellman Equations wiki page.

🚀 Solve

def solve_vaf(self, gamma, **kwargs):

Solves the approximate forward Bellman equations for the state-action value function $Q^\pi$.

Args:

  • gamma (float): Discount factor $\gamma$.
  • projected (bool, kwargs): Whether to solve the version of the Bellman equations projected onto the support of the dataset distribution $d^D$.

Returns:

  • Q_hat (np.ndarray): Approximated state-action value function $\hat Q^\pi$.
  • info (dict): Additional metadata from the solver.
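
As an illustration of what solving approximate forward Bellman equations from logged data can look like in the tabular case, here is a minimal sketch. It is not the library's implementation. The helper names `estimate_model` and `solve_vaf_sketch`, the flat transition arrays `obs`, `act`, `rew`, `obs_next`, and the target-policy table `pi` are all hypothetical. The sketch does not model the `projected` option. Unvisited state-action pairs get an all-zero model row, so their $\hat Q^\pi$ comes out as 0.

```python
import numpy as np


def estimate_model(obs, act, rew, obs_next, n_obs, n_act):
    """Hypothetical helper: empirical tabular model from logged (s, a, r, s') tuples."""
    counts = np.zeros((n_obs, n_act))
    rew_sum = np.zeros((n_obs, n_act))
    trans_counts = np.zeros((n_obs, n_act, n_obs))
    np.add.at(counts, (obs, act), 1.0)
    np.add.at(rew_sum, (obs, act), rew)
    np.add.at(trans_counts, (obs, act, obs_next), 1.0)
    seen = counts > 0
    # Mean rewards and transition frequencies; unvisited pairs keep all-zero entries
    r_hat = np.divide(rew_sum, counts, out=np.zeros_like(rew_sum), where=seen)
    T_hat = np.divide(trans_counts, counts[:, :, None],
                      out=np.zeros_like(trans_counts), where=seen[:, :, None])
    return counts, r_hat, T_hat


def solve_vaf_sketch(gamma, obs, act, rew, obs_next, pi, n_obs, n_act):
    """Hypothetical stand-in for solve_vaf: solve Q = r_hat + gamma * P_hat^pi Q."""
    counts, r_hat, T_hat = estimate_model(obs, act, rew, obs_next, n_obs, n_act)
    # P[(s, a), (s', a')] = T_hat(s' | s, a) * pi(a' | s'), flat index s * n_act + a
    P = (T_hat[:, :, :, None] * pi[None, None, :, :]).reshape(n_obs * n_act, -1)
    A = np.eye(n_obs * n_act) - gamma * P
    Q = np.linalg.solve(A, r_hat.reshape(-1))
    return Q.reshape(n_obs, n_act), {"counts": counts}


# Toy log: state 0 moves to state 1, which loops on itself and pays reward 1 per step
obs = np.array([0, 1, 1])
act = np.array([0, 0, 0])
rew = np.array([0.0, 1.0, 1.0])
obs_next = np.array([1, 1, 1])
pi = np.array([[1.0, 0.0], [1.0, 0.0]])  # target policy always picks action 0

Q_hat, info = solve_vaf_sketch(0.9, obs, act, rew, obs_next, pi, n_obs=2, n_act=2)
print(Q_hat)  # approximately [[9, 0], [10, 0]]
```

Solving the linear system directly, instead of iterating the Bellman operator, is practical here because the tabular system only has $|S| \cdot |A|$ unknowns.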

def solve_pv(self, gamma, Q_hat, info):

Estimates the policy value $\rho^\pi$ from the approximated initial state-action distribution and state-action value function $\hat Q^\pi$.

Args:

  • gamma (float): Discount factor $\gamma$.
  • Q_hat (np.ndarray): Approximated state-action value function $\hat Q^\pi$.
  • info (dict): Additional metadata from the solver.

Returns:

  • rho_hat (float): Estimated policy value $\hat \rho^\pi$.
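
The step from $\hat Q^\pi$ to $\hat \rho^\pi$ can be sketched as follows. The empirical initial-state distribution is combined with the target policy to get the initial state-action distribution, which then weights $\hat Q^\pi$. The function name `solve_pv_sketch`, its arguments, and the logged initial states `obs_init` are hypothetical. The $(1-\gamma)$ factor assumes the normalized policy-value convention; drop it if the library reports the unnormalized discounted return.

```python
import numpy as np


def solve_pv_sketch(gamma, Q_hat, pi, obs_init, n_obs):
    """Hypothetical stand-in for solve_pv under the normalized convention
    rho_hat = (1 - gamma) * sum_{s,a} d0_hat(s) * pi(a | s) * Q_hat(s, a)."""
    # Empirical initial-state distribution from the logged initial states
    d0_hat = np.bincount(obs_init, minlength=n_obs) / len(obs_init)
    # Approximated initial state-action distribution under the target policy
    d0_pi = d0_hat[:, None] * pi
    return (1.0 - gamma) * float(np.sum(d0_pi * Q_hat))


Q_hat = np.array([[9.0, 0.0], [10.0, 0.0]])
pi = np.array([[1.0, 0.0], [1.0, 0.0]])
obs_init = np.array([0, 0, 1, 1])  # hypothetical logged initial states

rho_hat = solve_pv_sketch(0.9, Q_hat, pi, obs_init, n_obs=2)
print(rho_hat)  # approximately 0.95
```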

def solve(self, gamma, **kwargs):

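Runs the full evaluation pipeline: estimates $Q^\pi$ via solve_vaf and computes $\rho^\pi$ via solve_pv.

Args:

  • gamma (float): Discount factor $\gamma$.
  • projected (bool, kwargs): Passed on to solve_vaf.

Returns:

  • rho_hat (float): Estimated policy value $\hat \rho^\pi$.
  • info (dict): Additional metadata from the solver, including the approximated state-action value function $\hat Q^\pi$ under the key "Q_hat".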

🧪 Example

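```python
# `dataset` holds the logged data and is assumed to be defined beforehand
vafe = TabularVafe(dataset, n_obs=5, n_act=3)  # 5 tabular observations, 3 actions
gamma = 0.99

# Solve the Bellman equations projected onto the support of d^D, then estimate rho
rho_hat, info = vafe.solve(gamma, projected=True)
Q_hat = info["Q_hat"]  # approximated state-action value function
```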