API 1.2. TabularSarsa

TabularSarsa

TabularSarsa estimates the policy value $\rho^\pi$ by first approximating the state-action value function $Q^\pi: S \times A \to \mathbb{R}$ in the tabular setting. It is an offline, off-policy version of the Expected SARSA algorithm. It inherits from TabularOffPE but overrides all the necessary base methods.

🧮 Mathematical Formulation

To obtain the state-action value function, it performs sequential Expected SARSA updates using dataset samples rather than online environment interactions:

$$
\begin{align*}
Q^{(i+1)}(s, a) &= Q^{(i)}(s, a) + \alpha \left( r + \gamma \, E_{a' \sim \pi(s')} \left[ Q^{(i)}(s', a') \right] - Q^{(i)}(s, a) \right) \\
&= (1 - \alpha) \, Q^{(i)}(s, a) + \alpha \left( r + \gamma \, E_{a' \sim \pi(s')} \left[ Q^{(i)}(s', a') \right] \right).
\end{align*}
$$

For further details, refer to Sutton & Barto (2018).
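For intuition, here is a minimal sketch of a single Expected SARSA update on a dense Q-table. The function name and the 2-D array layout are illustrative only; the class itself stores $Q$ as a flat vector and performs its updates via update_Q_hat_by_sample (see below).

import numpy as np

def expected_sarsa_update(Q, obs, act, rew, obs_next, probs_next, gamma, alpha):
    # Expected value of Q(s', .) under the target policy pi(. | s').
    expected_next = np.dot(probs_next, Q[obs_next])
    # Tabular Expected SARSA update with learning rate alpha.
    Q[obs, act] = (1 - alpha) * Q[obs, act] + alpha * (rew + gamma * expected_next)
    return Q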

๐Ÿ—๏ธ Constructor

def __init__(self, dataset, n_obs, n_act):

Args:

  • dataset (pd.DataFrame): Dataset with columns:
    • id (int): Episode identifier.
    • t (int): Time step index $t$.
    • obs_init (int or NDArray[float]): Initial observation (state) $s_0$.
    • obs (int or NDArray[float]): Current observation (state) $s$.
    • act (int): Action $a$.
    • rew (float): Reward $R(s, a, s')$.
    • obs_next (int or NDArray[float]): Next observation (state) $s'$.
    • probs_init_evaluation or probs_init (NDArray[float]): Action probabilities under the target policy at the initial state $\pi(\cdot \mid s_0)$.
    • probs_next_evaluation or probs_next (NDArray[float]): Action probabilities under the target policy at the next state $\pi(\cdot \mid s')$.
  • n_obs (int): Number of states $|S|$.
  • n_act (int): Number of actions $|A|$.
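A toy dataset with the required columns might look as follows; this is a sketch whose values are made up for illustration, using the probs_init / probs_next variants of the probability columns.

import numpy as np
import pandas as pd

# Hypothetical single episode with two transitions, 5 states, and 2 actions.
dataset = pd.DataFrame({
    "id":         [0, 0],
    "t":          [0, 1],
    "obs_init":   [0, 0],
    "obs":        [0, 3],
    "act":        [1, 0],
    "rew":        [0.0, 1.0],
    "obs_next":   [3, 4],
    "probs_init": [np.array([0.5, 0.5]), np.array([0.5, 0.5])],
    "probs_next": [np.array([0.9, 0.1]), np.array([0.2, 0.8])],
})

sarsa = TabularSarsa(dataset, n_obs=5, n_act=2)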

🚀 Solve

def solve_vaf(self, gamma, **kwargs):

Solves for the state-action value function $Q^\pi$ by dispatching the learning loop to either a fixed number of steps or a fixed number of epochs.

Args:

  • gamma (float): Discount factor $\gamma$.
  • n_steps (int, kwargs): Number of update steps.
  • n_epochs (int, kwargs): Number of full-dataset passes.
  • by (str, kwargs): Update strategy, either "samples" or "episodes".
  • alpha (float, kwargs): Learning rate $\alpha$.
  • shuffle (bool, kwargs): Whether to shuffle the dataset after each epoch.
  • verbosity (int, kwargs): Level of progress logging.
  • get_metrics (callable, kwargs): A metric collector function that takes (Q, rho) as arguments and returns a dictionary with str keys and int or float values.

Returns:

  • Q_hat (np.ndarray): Approximated state-action value function $\hat Q^\pi$.
  • info (dict): Dictionary containing tracked metrics and additional metadata from the solver.

Either n_steps or n_epochs must be provided, but not both.
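For example, a call that trains for two full passes over the dataset might look like this (a sketch, assuming a TabularSarsa instance sarsa as in the example at the bottom of this page):

# Approximate Q with two full-dataset passes, updating episode by episode.
Q_hat, info = sarsa.solve_vaf(
    gamma=0.99,
    n_epochs=2,
    by="episodes",
    alpha=0.1,
    shuffle=True,
    verbosity=1,
    get_metrics=lambda Q, rho: {"rho": rho},
)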

def solve_pv(self, gamma, Q_hat):

Estimates the policy value $\rho^\pi$ from the approximated state-action value function $\hat Q^\pi$ via

$$ \hat \rho^\pi \doteq (1 - \gamma) \, \frac{1}{n} \sum_{i=1}^n E_{ a_0 \sim \pi(s_0^{(i)}) } \left[ \hat Q(s_0^{(i)}, a_0) \right], $$

where $s_0^{(i)}$ denotes the initial state of the $i$-th sample.

Depending on whether the dataset is provided as plain samples or as episodes, i.e. whether "t" is part of dataset.columns, either all samples or only those with dataset["t"] == 0 are used.
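The following sketch spells out the computation for the episodic case; the flat indexing Q_hat[obs * n_act + act] is an assumption made purely for illustration.

import numpy as np

def policy_value_sketch(dataset, Q_hat, gamma, n_act):
    # Restrict to initial-state samples (episodic case, "t" in dataset.columns).
    init = dataset[dataset["t"] == 0]
    # Average E_{a_0 ~ pi(s_0)}[Q_hat(s_0, a_0)] over all initial states.
    values = [
        np.dot(row["probs_init"],
               Q_hat[row["obs_init"] * n_act:(row["obs_init"] + 1) * n_act])
        for _, row in init.iterrows()
    ]
    return (1 - gamma) * np.mean(values)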

def solve(self, gamma, **kwargs):

Runs the full evaluation pipeline: estimates $Q^\pi$ via solve_vaf and computes $\rho^\pi$ via solve_pv.

⚙️ Utility

def get_value(self, vector, obs, act):

Returns the state-action value for a specific state-action pair from a flat Q-vector.

Args:

  • vector (tensor): Q-value vector $Q$.
  • obs (int or tensor): Observation (state) $s$.
  • act (int or tensor): Action $a$.

Returns:

  • value (float or tensor): Q-value $Q(s, a)$.

Make sure that obs and act are either both int or both tensors.
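As a rough sketch, and assuming the flat Q-vector is laid out state-major (the actual indexing convention is internal to the class), the lookup amounts to:

# Assumed flat layout: Q(s, a) sits at index obs * n_act + act.
value = Q_hat[obs * n_act + act]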

def get_average_value(self, vector, obs, probs):

Computes the expected state-action value for a state from a flat Q-vector.

Args:

  • vector (tensor): Q-value vector $Q$.
  • obs (int or tensor): Observation (state) $s$.
  • probs (np.ndarray): Array of action probabilities under the target policy $\pi(a \mid s)$.

Returns:

  • average_value (float or tensor): expected Q-value $E_{a \sim \pi(s)}[Q(s, a)]$.

Make sure that probs is shaped according to obs, i.e. probs.shape == obs.shape + (n_act,).
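Under the same assumed state-major layout, the expectation reduces to a dot product over the action slice of obs:

import numpy as np

# Assumed flat layout: Q(s, .) occupies the slice [obs * n_act, (obs + 1) * n_act).
average_value = np.dot(probs, Q_hat[obs * n_act:(obs + 1) * n_act])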

def update_Q_hat_by_sample(self, Q_hat, sample, gamma, alpha):

Performs a SARSA update using a single sample.

def update_Q_hat_by_episode(self, Q_hat, episode, gamma, alpha):

Performs SARSA updates over all samples in an episode, iterating in reverse temporal order.
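Conceptually, the episode update corresponds to the following sketch, where episode is assumed to be the DataFrame slice of a single id and update_Q_hat_by_sample is assumed to return the updated Q-vector:

# Apply single-sample updates over the episode in reverse temporal order.
for _, sample in episode.sort_values("t", ascending=False).iterrows():
    Q_hat = sarsa.update_Q_hat_by_sample(Q_hat, sample, gamma, alpha)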

def solve_vaf_n_steps(self, gamma, n_steps, by, alpha, get_metrics, verbosity, Q_hat, info):

Runs the value function solver using a given number of steps.

def solve_vaf_n_epochs(self, gamma, n_epochs, by, alpha, get_metrics, shuffle, verbosity, Q_hat, info):

Runs the value function solver using a given number of epochs.

🧪 Example

sarsa = TabularSarsa(dataset, n_obs=5, n_act=2)

rho_hat, info = sarsa.solve(
    gamma=0.99,
    n_steps=10_000,
    by="samples",
    alpha=0.1,
    get_metrics=lambda Q, rho: {"rho": rho},
    verbosity=1,
)