# API 1.3.1.2. TabularDice
## TabularDice

`TabularDice` estimates the policy value $\rho^\pi$ by first approximating the stationary distribution correction $w_{\pi/D}: S \times A \to \mathbb{R}_{\geq 0}$ in the tabular setting. It is a model-based estimator that leverages auxiliary statistics derived from logged data. It inherits from `ModelBasedTabularOffPE`, but overrides all the necessary base methods.
## 🧮 Mathematical Formulation
In order to obtain the stationary distribution correction, `TabularDice` approximates the (modified) backward Bellman equations and solves them directly.

Dual to `TabularDice`, `TabularVafe` uses the forward Bellman equations and the state-action value function $Q^\pi$ to estimate the policy value $\rho^\pi$.

For further details, refer to the Bellman Equations wiki page.
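For intuition, a standard form of the backward Bellman equations from the DICE literature is sketched below; the exact (projected, modified) variants used by the package are defined on the Bellman Equations page, so treat this only as an illustration of the idea:

$$
d^\pi(s', a') = (1 - \gamma)\, d_0(s')\, \pi(a' \mid s') + \gamma \sum_{s, a} \pi(a' \mid s')\, P(s' \mid s, a)\, d^\pi(s, a),
\qquad
w_{\pi / D}(s, a) = \frac{d^\pi(s, a)}{d^D(s, a)}.
$$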
## Solve
```python
def solve_sdc(self, gamma, **kwargs):
```
Solves for the stationary distribution correction $w_{\pi / D}$ using the approximate (modified) backward Bellman equations.
Args:
- `gamma` (float): Discount factor $\gamma$.
- `projected` (bool, kwargs): Whether to solve the version of the Bellman equations projected onto the support of the dataset distribution $d^D$.
- `modified` (bool, kwargs): Whether to use the modified or the standard backward Bellman equations.

Returns:
- `w_hat` (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
- `info` (dict): Additional metadata from the solver.
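In the tabular setting, such a backward Bellman solve reduces to a linear system. The following is a minimal NumPy sketch, not the package's implementation: it assumes the transition model `P`, target policy `pi`, initial distribution `d0`, and dataset distribution `dD` have already been estimated from the logged data (the class derives such auxiliary statistics itself), and it ignores the `projected` and `modified` options.

```python
import numpy as np

def backward_bellman_sdc_sketch(P, pi, d0, dD, gamma):
    """Hypothetical helper: solve the standard backward Bellman equations
    for the discounted stationary distribution d^pi and return w = d^pi / d^D.

    P:  (S, A, S) transition probabilities P(s' | s, a)
    pi: (S, A)    target-policy probabilities pi(a | s)
    d0: (S,)      initial state distribution
    dD: (S, A)    dataset state-action distribution d^D(s, a)
    """
    S, A = pi.shape
    n = S * A

    # State-action transition matrix under pi:
    # P_pi[(s, a), (s', a')] = P(s' | s, a) * pi(a' | s')
    P_pi = (P[:, :, :, None] * pi[None, None, :, :]).reshape(n, n)

    # Initial state-action distribution under pi
    d0_pi = (d0[:, None] * pi).reshape(n)

    # Backward Bellman equation: d = (1 - gamma) * d0_pi + gamma * P_pi.T @ d
    d_pi = np.linalg.solve(np.eye(n) - gamma * P_pi.T, (1 - gamma) * d0_pi)

    # Stationary distribution correction on the support of d^D
    dD_flat = dD.reshape(n)
    return np.divide(d_pi, dD_flat, out=np.zeros(n), where=dD_flat > 0)
```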
```python
def solve_pv(self, w_hat, weighted):
```
Estimates the policy value $\rho^\pi$ from the approximated stationary distribution correction $\hat w_{\pi / D}$.
Args:
- `w_hat` (np.ndarray): Approximated stationary distribution correction $\hat w_{\pi / D}$.
- `weighted` (bool): Whether to normalize by the sum of weights `np.dot(w, dD)` instead of the number of samples `n`.
Returns:
- `rho_hat` (float): Estimated policy value $\hat \rho^\pi$.
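For intuition, here is a minimal sketch of the weighted (self-normalized) estimate described above. The helper name and the use of per-pair sample counts are assumptions made for illustration, not the package's API.

```python
import numpy as np

def policy_value_sketch(w_hat, r, counts, weighted=True):
    """Hypothetical helper: tabular policy-value estimate from a correction.

    w_hat:  (n_sa,) correction w(s, a) per state-action pair
    r:      (n_sa,) mean observed reward per state-action pair
    counts: (n_sa,) number of logged samples per state-action pair
    """
    n = counts.sum()  # total number of samples in the dataset
    # Sum over samples of w(s_i, a_i) * r_i, accumulated per state-action pair
    total = np.dot(w_hat * counts, r)
    # Normalizer: sum of weights (weighted) or the sample count n
    norm = np.dot(w_hat, counts) if weighted else n
    return total / norm
```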
```python
def solve(self, gamma, **kwargs):
```

Runs the full evaluation pipeline: estimates $w_{\pi / D}$ via `solve_sdc` and computes $\rho^\pi$ via `solve_pv`. Returns the estimated policy value `rho_hat` together with an `info` dict containing the approximated correction `w_hat` (see the example below).
## 🧪 Example
```python
# Assumes `TabularDice` has been imported from the package and that
# `dataset` already holds the prepared logged data.

# Tabular estimator over 5 observations (states) and 3 actions.
dice = TabularDice(dataset, n_obs=5, n_act=3)

gamma = 0.99

# Estimate the policy value (and the correction) in one call.
rho_hat, info = dice.solve(
    gamma,
    projected=True,
    modified=True,
    weighted=True,
)

# The approximated stationary distribution correction.
w_hat = info["w_hat"]
```
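Equivalently, the two stages can be run separately. This is a sketch based on the signatures documented above, not a verified snippet from the package:

```python
# Two-step equivalent of dice.solve(...) (sketch; keyword handling may differ).
w_hat, info_sdc = dice.solve_sdc(gamma, projected=True, modified=True)
rho_hat = dice.solve_pv(w_hat, weighted=True)
```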