TabularDualDice

TabularDualDice estimates the policy value $\rho^\pi$ by first approximating the stationary distribution correction $w_{\pi/D}: S \times A \to \mathbb{R}_{\geq 0}$ in the tabular setting. It is a model-based estimator: the tabular quantities it relies on (such as the transition matrix $P^\pi$ under the target policy and the distributions $d^D$ and $d^\pi_0$) are estimated from logged data. It inherits from TabularDice, but overrides the stationary distribution correction solver solve_sdc.

Note that TabularDualDice only supports the discounted case, i.e., $0 < \gamma < 1$.

🧮 Mathematical Formulation

To obtain the stationary distribution correction, the estimator minimizes the primal objective of DualDICE by setting its derivative to zero and solving the resulting system of linear equations (the first-order conditions are linear for $p = 2$). This yields a primal solution $\hat v^\ast$, from which the corresponding dual solution $\hat w^\ast$, i.e., the approximate stationary distribution correction $\hat w_{\pi / D}$, is extracted.

The primal objective function used in DualDICE is:

$$ J(v) \doteq (1 - \gamma) E_{ (s_0, a_0) \sim d^\pi_0} [ v(s_0, a_0) ] + E_{ (s, a) \sim d^D } \left[ \phi \big( (\gamma P^\pi v - v)(s, a) \big) \right], \quad \phi(x) \doteq \tfrac{1}{p} |x|^p, \quad p > 1. $$

Minimizing this objective with respect to $v$ yields the solution

$$ v^\ast \doteq \arg \min_v J(v). $$
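
For $p = 2$, the first-order condition $\nabla J(v) = 0$ is a linear system. Writing $v$ and $d^\pi_0$ as vectors over state-action pairs and abbreviating $D \doteq \mathrm{diag}(d^D)$ (notation introduced here for illustration only), it reads:

$$ (I - \gamma P^\pi)^\top D \, (I - \gamma P^\pi) \, v^\ast = - (1 - \gamma) \, d^\pi_0. $$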

The corresponding stationary distribution correction $w_{\pi / D}$ is then extracted (for $p = 2$) using:

$$ w_{\pi / D} = w^\ast = - (I - \gamma P^\pi) v^\ast. $$
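
The following is a minimal NumPy sketch of this computation, not the library's actual implementation. The tabular quantities P_pi (state-action transition matrix under the target policy), d0 (initial state-action distribution $d^\pi_0$), and dD (dataset distribution $d^D$) are assumed to be given as dense arrays over the state-action pairs:

```python
import numpy as np

def dual_dice_tabular(P_pi, d0, dD, gamma):
    """Solve the p = 2 DualDICE first-order conditions in the tabular setting."""
    n = len(d0)
    A = np.eye(n) - gamma * P_pi   # I - gamma * P^pi
    D = np.diag(dD)                # diag(d^D)

    # First-order condition: A^T D A v = -(1 - gamma) * d0.
    # lstsq also handles the singular case where d^D has zero entries,
    # which loosely corresponds to the `projected` option of solve_sdc.
    v_hat, *_ = np.linalg.lstsq(A.T @ D @ A, -(1 - gamma) * d0, rcond=None)

    # Dual extraction: w = -(I - gamma * P^pi) v.
    w_hat = -A @ v_hat
    return w_hat, v_hat
```

As a sanity check, the resulting $\hat w$ satisfies $(I - \gamma P^\pi)^\top \mathrm{diag}(d^D) \, \hat w = (1 - \gamma) \, d^\pi_0$, i.e., the (transposed) Bellman flow equations for $d^\pi = \mathrm{diag}(d^D) \, w_{\pi/D}$.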

Like TabularDualDice, TabularGradientDice also solves for the stationary distribution correction by minimizing a primal objective, albeit the one from GradientDICE.

For further details, refer to the original paper *DualDICE: Behavior-Agnostic Estimation of Discounted Stationary Distribution Corrections*.

🚀 Solve

```python
def solve_sdc(self, gamma, **kwargs):
```

Solves for the stationary distribution correction $w_{\pi / D}$ by first computing the primal solution $v^\ast$, obtained by minimizing the primal objective of DualDICE.

Args:

  • gamma (float): Discount factor $\gamma$.
  • projected (bool, kwargs): Whether to solve the version of the Bellman equations projected onto the support of the dataset distribution $d^D$.

Returns:

  • w_hat (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
  • info (dict): Contains estimated primal solution $\hat v^\ast$ and additional metadata from the solver.
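
A hedged usage sketch, assuming an estimator constructed as in the example at the bottom of this page:

```python
w_hat, info = estimator.solve_sdc(gamma=0.99, projected=True)
```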

```python
def solve_pv(self, w_hat, weighted):
```

Estimates the policy value $\rho^\pi$ from the approximated stationary distribution correction $\hat w_{\pi / D}$.

Args:

  • w_hat (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
  • weighted (bool): Whether to normalize by the sum of weights np.dot(w, dD) instead of the number of samples n.

Returns:

  • rho_hat (float): Estimated policy value $\hat \rho^\pi$.
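
For intuition, a rough sketch of the two normalizations in tabular form (hypothetical names: r for the expected reward per state-action pair, dD for $d^D$; not the library's exact code):

```python
import numpy as np

def policy_value(w_hat, r, dD, weighted):
    # Tabular analogue of the sample average (1/n) * sum_i w_i * r_i.
    num = np.dot(dD * w_hat, r)
    # Normalize by the sum of weights np.dot(w_hat, dD) instead of n.
    den = np.dot(w_hat, dD) if weighted else 1.0
    return num / den
```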

```python
def solve(self, gamma, **kwargs):
```

Runs the full evaluation pipeline: estimates $w_{\pi / D}$ via solve_sdc and computes $\rho^\pi$ via solve_pv.
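
Plausibly, the two steps compose roughly as follows (a sketch, not the actual source):

```python
def solve(self, gamma, **kwargs):
    w_hat, info = self.solve_sdc(gamma, **kwargs)
    rho_hat = self.solve_pv(w_hat, weighted=kwargs.get("weighted", False))
    return rho_hat, info
```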

🧪 Example

```python
# dataset, n_obs, and n_act are assumed to be prepared beforehand:
# the logged transitions, the number of observations, and the number of actions.
estimator = TabularDualDice(dataset, n_obs, n_act)

# Discount factor; TabularDualDice requires 0 < gamma < 1.
gamma = 0.99

rho_hat, info = estimator.solve(gamma, projected=True, weighted=True)
```
โš ๏ธ **GitHub.com Fallback** โš ๏ธ