API 1.3.2.1. TabularDualDice
`TabularDualDice` estimates the policy value $\rho^\pi$. It inherits from `TabularDice`, but overrides the stationary distribution correction solver `solve_sdc`.
Note that `TabularDualDice` only supports the discounted case, i.e., $\gamma < 1$.
To obtain the stationary distribution correction, it minimizes the primal objective of DualDICE by setting its derivative to zero and solving the resulting system of linear equations. This yields a primal solution $\hat v^\ast$.
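The primal objective function used in DualDICE is:

$$
J(v) = \frac{1}{2} \, \mathbb{E}_{(s, a) \sim d^D} \left[ \left( v(s, a) - \gamma \, \mathbb{E}_{s' \sim T(s, a), \, a' \sim \pi(s')} \left[ v(s', a') \right] \right)^2 \right] - (1 - \gamma) \, \mathbb{E}_{s_0 \sim d_0, \, a_0 \sim \pi(s_0)} \left[ v(s_0, a_0) \right]
$$

Here $T$ denotes the transition kernel and $d_0$ the initial state distribution. Minimizing this objective with respect to $v$ yields the primal solution $v^\ast$. The corresponding stationary distribution correction is the Bellman residual of $v^\ast$:

$$
w_{\pi / D}(s, a) = v^\ast(s, a) - \gamma \, \mathbb{E}_{s' \sim T(s, a), \, a' \sim \pi(s')} \left[ v^\ast(s', a') \right]
$$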
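In matrix form, setting the derivative to zero can be sketched with NumPy. This is a minimal illustration under stated assumptions, not the library's implementation: `P_pi` (the state-action transition matrix under the target policy), `d0_pi` (the initial state-action distribution), and the function name `dual_dice_closed_form` are hypothetical, and `dD` is assumed strictly positive so that the linear system has a unique solution. With $A = I - \gamma P^\pi$ and $D = \operatorname{diag}(d^D)$, the stationarity condition reads $A^\top D A \, v = (1 - \gamma) \, d_0^\pi$, and the correction is $w = A v$.

```python
import numpy as np


def dual_dice_closed_form(P_pi, d0_pi, dD, gamma):
    """Hypothetical sketch of the tabular DualDICE solution.

    P_pi  : (m, m) row-stochastic matrix over state-action pairs
    d0_pi : (m,) initial state-action distribution
    dD    : (m,) dataset distribution, assumed strictly positive
    gamma : discount factor, assumed to satisfy 0 < gamma < 1
    """
    m = len(dD)
    A = np.eye(m) - gamma * P_pi      # A @ v is the Bellman residual of v
    D = np.diag(dD)
    # Zero gradient of the quadratic objective: A^T D A v = (1 - gamma) d0_pi
    v_hat = np.linalg.solve(A.T @ D @ A, (1 - gamma) * d0_pi)
    w_hat = A @ v_hat                 # correction = Bellman residual of v_hat
    return w_hat, {"v_hat": v_hat}


# Toy problem with 4 state-action pairs (random, for illustration only).
rng = np.random.default_rng(0)
m = 4
P_pi = rng.random((m, m))
P_pi /= P_pi.sum(axis=1, keepdims=True)
d0_pi = rng.random(m)
d0_pi /= d0_pi.sum()
dD = rng.random(m) + 0.1
dD /= dD.sum()
gamma = 0.9

w_hat, info = dual_dice_closed_form(P_pi, d0_pi, dD, gamma)

# Reference: discounted stationary distribution of the target policy,
# d_pi = (1 - gamma) (I - gamma P_pi^T)^{-1} d0_pi.
d_pi = (1 - gamma) * np.linalg.solve(np.eye(m) - gamma * P_pi.T, d0_pi)
```

In this exact tabular setting, `w_hat * dD` should coincide with `d_pi`, which is a convenient sanity check for the linear solve.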
Similar to `TabularDualDice`, `TabularGradientDice` also solves for the stationary distribution correction by minimizing a primal objective, albeit the one from GradientDICE.
For further details, refer to the original paper, "DualDICE: Behavior-Agnostic Estimation of Discounted Stationary Distribution Corrections".
def solve_sdc(self, gamma, **kwargs):
Solves for the stationary distribution correction $\hat w_{\pi / D}$.
Args:
- `gamma` (float): Discount factor $\gamma$.
- `projected` (bool, kwargs): Whether to solve the version of the Bellman equations projected onto the support of the dataset distribution $d^D$.

Returns:
- `w_hat` (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
- `info` (dict): Contains the estimated primal solution $\hat v^\ast$ and additional metadata from the solver.
def solve_pv(self, w_hat, weighted):
Estimates the policy value $\hat \rho^\pi$ from the approximated stationary distribution correction $\hat w_{\pi / D}$.
Args:
- `gamma` (float): Discount factor $\gamma$.
- `w_hat` (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
- `weighted` (bool): Whether to normalize by the sum of weights `np.dot(w, dD)` instead of the number of samples `n`.

Returns:
- `rho_hat` (float): Estimated policy value $\hat \rho^\pi$.
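The effect of `weighted` can be illustrated with a small NumPy sketch. It is an assumption-laden illustration rather than the library's code: the function name `policy_value` and the vector `r` of expected rewards per state-action pair are hypothetical, and `dD` is assumed to hold empirical frequencies (counts divided by `n`), so the unweighted estimate already carries the `1 / n` normalization.

```python
import numpy as np


def policy_value(w_hat, dD, r, weighted):
    """Hypothetical sketch of a tabular policy value estimate.

    w_hat    : (m,) approximated stationary distribution correction
    dD       : (m,) empirical dataset distribution (counts / n)
    r        : (m,) expected reward per state-action pair (assumed given)
    weighted : normalize by np.dot(w_hat, dD) instead of by n
    """
    # Equals (1 / n) * sum over samples of w * r when dD = counts / n.
    numerator = np.dot(dD, w_hat * r)
    if weighted:
        # Self-normalized variant: divide by the total weight mass.
        return numerator / np.dot(w_hat, dD)
    return numerator


counts = np.array([2, 1, 1])
dD = counts / counts.sum()
w_hat = np.array([0.5, 1.0, 3.0])
r = np.array([1.0, 0.0, 2.0])

rho_plain = policy_value(w_hat, dD, r, weighted=False)
rho_weighted = policy_value(w_hat, dD, r, weighted=True)
```

The two estimates agree whenever `np.dot(w_hat, dD)` equals 1, which holds for the exact correction; with an approximate `w_hat`, the weighted variant rescales the estimate so that the weights sum to one under `dD`.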
def solve(self, gamma, **kwargs):
Runs the full evaluation pipeline: estimates the stationary distribution correction via `solve_sdc` and computes the policy value via `solve_pv`.
estimator = TabularDualDice(dataset, n_obs, n_act)
gamma = 0.99
rho_hat, info = estimator.solve(gamma, projected=True, weighted=True)
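The control flow behind the call above can be pictured with the following sketch. The classes `PipelineSketch` and `ConstantStub` are hypothetical and exist only for illustration. The routing of keyword arguments (`weighted` consumed by `solve_pv`, everything else forwarded to `solve_sdc`) is an assumption inferred from the documented signatures, not a statement about the library's actual implementation.

```python
class PipelineSketch:
    """Hypothetical stand-in for the estimator's control flow."""

    def solve_sdc(self, gamma, **kwargs):
        raise NotImplementedError

    def solve_pv(self, w_hat, weighted):
        raise NotImplementedError

    def solve(self, gamma, **kwargs):
        # Assumption: `weighted` belongs to solve_pv, while the remaining
        # keyword arguments (e.g. `projected`) are forwarded to solve_sdc.
        weighted = kwargs.pop("weighted")
        w_hat, info = self.solve_sdc(gamma, **kwargs)
        rho_hat = self.solve_pv(w_hat, weighted)
        return rho_hat, info


class ConstantStub(PipelineSketch):
    """Toy subclass that returns fixed values so the flow can be run."""

    def solve_sdc(self, gamma, **kwargs):
        return [1.0, 1.0], {"projected": kwargs.get("projected")}

    def solve_pv(self, w_hat, weighted):
        return sum(w_hat) / len(w_hat)


rho_hat, info = ConstantStub().solve(0.99, projected=True, weighted=True)
```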