# TabularSarsa

`TabularSarsa` estimates the policy value $\rho^\pi$ by first approximating the state-action value function $Q^\pi: S \times A \to \mathbb{R}$ in the tabular setting. It is an offline, off-policy version of the Expected SARSA algorithm. It inherits from `TabularOffPE`, but overrides all the necessary base methods.
## 🧮 Mathematical Formulation
To obtain the state-action value function, it executes sequential Expected SARSA updates, using dataset samples as opposed to online environment interactions:
$$
\begin{align*}
Q^{(i+1)}(s, a)
& = Q^{(i)}(s, a) + \alpha \left( r + \gamma \, E_{a' \sim \pi(s')} \left[ Q^{(i)}(s', a') \right] - Q^{(i)}(s, a) \right) \\
& = (1 - \alpha) \, Q^{(i)}(s, a) + \alpha \left( r + \gamma \, E_{a' \sim \pi(s')} \left[ Q^{(i)}(s', a') \right] \right).
\end{align*}
$$
For further details, refer to Sutton & Barto 2018.
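To make the update rule concrete, here is a minimal NumPy sketch of a single tabular Expected SARSA update from one offline transition $(s, a, r, s')$; the function name and the $(|S| \times |A|)$-table layout are illustrative assumptions, not the class's internal representation.

```python
import numpy as np

def expected_sarsa_update(Q, s, a, r, s_next, probs_next, gamma, alpha):
    """One tabular Expected SARSA update from an offline transition (s, a, r, s').

    Q:          (n_obs, n_act) table of current Q-value estimates, Q^{(i)}
    probs_next: target-policy action probabilities pi(. | s') at the next state
    """
    # E_{a' ~ pi(s')}[ Q^{(i)}(s', a') ]
    expected_q_next = np.dot(probs_next, Q[s_next])

    # Convex combination of the old estimate and the TD target, as in the formula above.
    Q[s, a] = (1 - alpha) * Q[s, a] + alpha * (r + gamma * expected_q_next)
    return Q
```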
## 🏗️ Constructor
```python
def __init__(self, dataset, n_obs, n_act):
```
Args:

- `dataset` (pd.DataFrame): Dataset with columns:
    - `id` (int): Episode identifier.
    - `t` (int): Time step index $t$.
    - `obs_init` (int or NDArray[float]): Initial observation (state) $s_0$.
    - `obs` (int or NDArray[float]): Current observation (state) $s$.
    - `act` (int): Action $a$.
    - `rew` (float): Reward $R(s, a, s')$.
    - `obs_next` (int or NDArray[float]): Next observation (state) $s'$.
    - `probs_init_evaluation` or `probs_init` (NDArray[float]): Action probabilities under the target policy at the initial state $\pi(\cdot \mid s_0)$.
    - `probs_next_evaluation` or `probs_next` (NDArray[float]): Action probabilities under the target policy at the next state $\pi(\cdot \mid s')$.
- `n_obs` (int): Number of states $|S|$.
- `n_act` (int): Number of actions $|A|$.
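For illustration, a toy dataset with discrete observations could be assembled as follows; the values are made up and only meant to show the expected schema (here using the `probs_init`/`probs_next` column names):

```python
import numpy as np
import pandas as pd

# One episode with two transitions; each probability row is a distribution over n_act = 2 actions.
dataset = pd.DataFrame({
    "id":         [0, 0],
    "t":          [0, 1],
    "obs_init":   [0, 0],
    "obs":        [0, 1],
    "act":        [1, 0],
    "rew":        [0.0, 1.0],
    "obs_next":   [1, 2],
    "probs_init": [np.array([0.5, 0.5]), np.array([0.5, 0.5])],
    "probs_next": [np.array([0.9, 0.1]), np.array([0.2, 0.8])],
})
```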
## 🚀 Solve
```python
def solve_vaf(self, gamma, **kwargs):
```
Solves for the state-action value function $Q^\pi$ by dispatching the learning loop to either a fixed number of steps or a fixed number of epochs.
Args:

- `gamma` (float): Discount factor $\gamma$.
- `n_steps` (int, kwargs): Number of gradient steps.
- `n_epochs` (int, kwargs): Number of full-dataset passes.
- `by` (str, kwargs): `"samples"` or `"episodes"`; determines the update strategy.
- `alpha` (float, kwargs): Learning rate $\alpha$.
- `shuffle` (bool, kwargs): Whether to shuffle the dataset after each epoch.
- `verbosity` (int, kwargs): Level of progress logging.
- `get_metrics` (callable, kwargs): A metric collector function taking `(Q, rho)` as arguments and returning a dictionary with str keys and int or float values.
Returns:

- `Q_hat` (np.ndarray): Approximated state-action value function $\hat Q^\pi$.
- `info` (dict): Dictionary containing tracked metrics and additional metadata from the solver.
Either `n_steps` or `n_epochs` must be provided, but not both.
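For instance, the value-function solver could be run for a fixed number of epochs over episodes while tracking the running policy-value estimate; the call below is a sketch based on the arguments documented above and assumes `sarsa` is a `TabularSarsa` instance as in the example at the bottom of this page.

```python
Q_hat, info = sarsa.solve_vaf(
    gamma=0.99,
    n_epochs=10,
    by="episodes",
    alpha=0.1,
    shuffle=True,
    verbosity=1,
    get_metrics=lambda Q, rho: {"rho": rho},
)
```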
```python
def solve_pv(self, gamma, Q_hat):
```
Estimates the policy value $\rho^\pi$ from the approximated state-action value function $\hat Q^\pi$ via
$$ \hat \rho^\pi \doteq (1 - \gamma) \, \frac{1}{n} \sum_{i=1}^n E_{ a_0 \sim \pi(s_0^{(i)}) } \left[ \hat Q^\pi(s_0^{(i)}, a_0) \right]. $$
Depending on whether the dataset is provided purely via samples or via episodes, i.e. whether `"t"` is part of `dataset.columns`, either all samples or only those with `dataset["t"] == 0` are used.
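As a rough sketch of this estimator, assuming the initial observations and their target-policy action probabilities have already been extracted from the dataset as arrays, and that the flat Q-vector is laid out observation-major:

```python
import numpy as np

def estimate_policy_value(Q_hat, obs_init, probs_init, gamma, n_act):
    """(1 - gamma)-weighted average of expected initial Q-values.

    Q_hat:      flat Q-vector of length n_obs * n_act (assumed observation-major)
    obs_init:   (n,) array of initial states s_0
    probs_init: (n, n_act) array of pi(. | s_0)
    """
    Q_table = Q_hat.reshape(-1, n_act)
    expected_q0 = np.einsum("na,na->n", probs_init, Q_table[obs_init])
    return (1 - gamma) * expected_q0.mean()
```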
```python
def solve(self, gamma, **kwargs):
```
Runs the full evaluation pipeline: estimates $Q^\pi$ via `solve_vaf` and computes $\rho^\pi$ via `solve_pv`.
## ⚙️ Utility
```python
def get_value(self, vector, obs, act):
```
Returns the state-action value for a specific state-action pair from a flat Q-vector.
Args:

- `vector` (tensor): Q-value vector $Q$.
- `obs` (int or tensor): Observation (state) $s$.
- `act` (int or tensor): Action $a$.

Returns:

- `value` (float or tensor): Q-value $Q(s, a)$.
Make sure that `obs` and `act` are either both int or both tensor.
```python
def get_average_value(self, vector, obs, probs):
```
Computes the expected state-action value for a state from a flat Q-vector.
Args:

- `vector` (tensor): Q-value vector $Q$.
- `obs` (int or tensor): Observation (state) $s$.
- `probs` (np.ndarray): Array of action probabilities under the target policy $\pi(a \mid s)$.

Returns:

- `average_value` (float or tensor): Expected Q-value $E_{a \sim \pi(s)}[Q(s, a)]$.
Make sure that `probs` is shaped according to `obs`, i.e. `probs.shape == obs.shape + (n_act,)`.
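A plausible NumPy implementation of these two helpers, assuming the flat Q-vector is indexed observation-major as `obs * n_act + act` (the actual class may use a different layout or tensor backend):

```python
import numpy as np

def get_value(vector, obs, act, n_act):
    # Flat indexing: Q(s, a) = vector[s * n_act + a]; works for ints and arrays alike.
    return vector[obs * n_act + act]

def get_average_value(vector, obs, probs, n_act):
    # E_{a ~ pi(s)}[Q(s, a)] = sum_a pi(a | s) * Q(s, a), broadcast over obs.
    acts = np.arange(n_act)
    return (probs * get_value(vector, np.expand_dims(obs, -1), acts, n_act)).sum(-1)

Q_vector = np.arange(10, dtype=float)                                   # n_obs=5, n_act=2
print(get_value(Q_vector, 3, 1, n_act=2))                               # 7.0
print(get_average_value(Q_vector, 3, np.array([0.25, 0.75]), n_act=2))  # 0.25*6 + 0.75*7 = 6.75
```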
```python
def update_Q_hat_by_sample(self, Q_hat, sample, gamma, alpha):
```

Performs an Expected SARSA update using a single sample.

```python
def update_Q_hat_by_episode(self, Q_hat, episode, gamma, alpha):
```

Performs Expected SARSA updates over all samples in an episode (in reverse order).

```python
def solve_vaf_n_steps(self, gamma, n_steps, by, alpha, get_metrics, verbosity, Q_hat, info):
```

Runs the value-function solver for a given number of steps.

```python
def solve_vaf_n_epochs(self, gamma, n_epochs, by, alpha, get_metrics, shuffle, verbosity, Q_hat, info):
```

Runs the value-function solver for a given number of epochs.
## 🧪 Example
```python
sarsa = TabularSarsa(dataset, n_obs=5, n_act=2)

rho_hat, info = sarsa.solve(
    gamma=0.99,
    n_steps=10_000,
    by="samples",
    alpha=0.1,
    shuffle=True,
    get_metrics=lambda Q, rho: {"rho": rho},
    verbosity=1,
)
```