# NeuralDice (Abstract Base Class)
The `NeuralDice` class defines a unified interface for continuous stationary distribution correction estimation (DICE) algorithms. It is meant to be extended by specific continuous DICE estimators.
It includes utilities for
- preprocessing input data,
- defining neural network architectures,
- computing losses, and
- running training loops.
## Constructor

```python
def __init__(
        self,
        gamma, seed, batch_size,
        learning_rate, hidden_dimensions,
        obs_min, obs_max, n_act, obs_shape,
        dataset, preprocess_obs=None, preprocess_act=None, preprocess_rew=None,
        dir=None, get_recordings=None, other_hyperparameters=None, save_interval=100):
```
Args:

- `gamma` (float): Discount factor $\gamma$.
- `seed` (int): Random seed for reproducibility.
- `batch_size` (int): Number of samples per batch.
- `learning_rate` (float or tf.keras.optimizers.schedules.LearningRateSchedule): Learning rate.
- `hidden_dimensions` (tuple): Number of neurons per hidden layer in the neural networks.
- `obs_min` (np.ndarray): Minimum bounds for observations (states) $s_\min$.
- `obs_max` (np.ndarray): Maximum bounds for observations (states) $s_\max$.
- `n_act` (int): Number of actions $|A|$.
- `obs_shape` (tuple): Shape of the observation space $S$.
- `dataset` (pd.DataFrame): Dataset with the following columns (see the sketch after this list):
  - `obs_init` (int or NDArray[float]): Initial observation (state) $s_0$.
  - `obs` (int or NDArray[float]): Current observation (state) $s$.
  - `act` (int): Action $a$.
  - `rew` (float): Reward $R(s, a, s')$.
  - `obs_next` (int or NDArray[float]): Next observation (state) $s'$.
  - `probs_next_evaluation` or `probs_next` (NDArray[float]): Action probabilities under the target policy at the next state $\pi(\cdot \mid s')$.
- `preprocess_obs` (callable, optional): Preprocessing function for observations (states).
- `preprocess_act` (callable, optional): Preprocessing function for actions.
- `preprocess_rew` (callable, optional): Preprocessing function for rewards.
- `dir` (str, optional): Directory for saving logs.
- `get_recordings` (callable, optional): Function to get additional logging information.
- `other_hyperparameters` (dict, optional): Any other (algorithm-specific) hyperparameters to label the logs.
- `save_interval` (int, optional): Interval (in steps) at which to save logs and the model.
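Below is a minimal sketch of how such a dataset might be assembled for `obs_shape=(8,)` and `n_act=4`; the column names follow the list above, while the transition values themselves are placeholders.

```python
import numpy as np
import pandas as pd

n_act = 4  # |A|

# Hypothetical logged transitions; in practice these come from a behavior policy.
df = pd.DataFrame({
    "obs_init": [np.zeros(8, dtype=np.float32)] * 3,                       # s_0
    "obs": [np.random.rand(8).astype(np.float32) for _ in range(3)],       # s
    "act": [0, 2, 1],                                                      # a
    "rew": [0.0, 1.0, -0.5],                                               # R(s, a, s')
    "obs_next": [np.random.rand(8).astype(np.float32) for _ in range(3)],  # s'
    # pi(. | s') as a probability vector over the n_act actions
    "probs_next": [np.full(n_act, 1 / n_act, dtype=np.float32)] * 3,
})
```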
All `preprocess_*` functions should take (a `pd.Series` of) an `obs`, `act`, or `rew` as an argument and return a preprocessed version. For `obs`, `self.preprocess_probs` is applied by default, since `probs` is usually of the same type as `obs`. For `act` and `rew`, no preprocessing is applied by default.
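For illustration, here is a sketch of custom preprocessing functions that scale observations into $[0, 1]$ using the known bounds and clip rewards; the normalization and clipping are assumptions, not library defaults.

```python
import numpy as np

obs_min = np.zeros(8, dtype=np.float32)  # example bounds, matching obs_shape=(8,)
obs_max = np.ones(8, dtype=np.float32)

def my_preprocess_obs(obs):
    """Scale (a pd.Series of) observations into [0, 1] using the known bounds."""
    obs = np.stack(list(obs)).astype(np.float32)
    return (obs - obs_min) / (obs_max - obs_min)

def my_preprocess_rew(rew):
    """Clip rewards to [-1, 1]."""
    return np.clip(np.asarray(rew, dtype=np.float32), -1.0, 1.0)
```

These would then be passed to the constructor as `preprocess_obs=my_preprocess_obs, preprocess_rew=my_preprocess_rew`.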
Define `get_recordings` by

```python
def get_recordings(
        estimator,
        obs_init, obs, act, obs_next,
        probs_init, probs_next,
        values, loss, gradients,
        pv_s, pv_w):
    ...
```

and let it return a `dict[str, float]`. More information on the arguments is found below.
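A minimal sketch of such a callback, which records the loss magnitude and the gap between the two policy-value estimates (the dictionary keys are arbitrary labels, not names expected by the library):

```python
def my_recordings(
        estimator,
        obs_init, obs, act, obs_next,
        probs_init, probs_next,
        values, loss, gradients,
        pv_s, pv_w):
    # Return a few extra scalars to be logged alongside the built-in recordings.
    return {
        "abs_loss": abs(float(loss)),
        "pv_gap": float(pv_s) - float(pv_w),
    }
```

Pass it to the constructor as `get_recordings=my_recordings`.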
Creates a dict `self.hyperparameters` with

- `"name"` (str): specific name of the child class estimator,
- `"gamma"`,
- `"seed"`,
- `"batch_size"`,
- `"learning_rate"`,
- `"hidden_dimensions"`,
- `"other"` (dict): `other_hyperparameters`.
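For the example constructor call at the bottom of this page, the resulting dictionary would look roughly as follows (the exact `"name"` value depends on the child class, and `"other"` is `None` when no `other_hyperparameters` are given):

```python
{
    "name": "SomeNeuralDiceChildClass",
    "gamma": 0.99,
    "seed": 0,
    "batch_size": 64,
    "learning_rate": 1e-3,
    "hidden_dimensions": (64, 64),
    "other": None,
}
```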
Then proceeds to set the seed and call `self.set_up_recording` and `self.set_up_networks`.
```python
def set_up_recording(self):
```

Initializes logging and recording at `self.save_dir`, including TensorBoard writers and unique run identifiers, and generates a unique evaluation id `self.id` via `datetime.now().isoformat()`.
```python
def save_hyperparameters(self):
```

Saves the model's hyperparameters and the unique evaluation id to a JSON file `evaluation.json` at `self.dir`.
```python
def set_up_networks(self):
```

Initializes the value networks `v` and `w` and their corresponding `SGD` optimizers.
## Properties

```python
@property
@abstractmethod
def __name__(self):
```

Should return the name of the specific estimator class.
```python
@property
def save_dir(self):
```

Joins the directory `self.dir` with the evaluation id `self.id`, provided that `self.id` is not `None`.
```python
@property
def output_activation_fn_dual(self):
```

Returns the activation function applied to the output of the dual value network. Defaults to `tf.identity`, but should be overridden in a child class if necessary.
## Solve

```python
def evaluate_loop(self, n_steps, verbosity=1, pbar_keys=None):
```

Runs the training loop over a specified number of steps, recording and logging progress.

Args:

- `n_steps` (int): Number of training steps.
- `verbosity` (int): Verbosity level for logging.
- `pbar_keys` (list[str], optional): Keys to display in the progress bar.
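For example, given an estimator instance (see the example at the bottom of this page), one could train for 10,000 steps while surfacing the custom recordings defined above in the progress bar; the keys are assumptions that must match whatever `get_recordings` returns.

```python
estimator.evaluate_loop(
    n_steps=10_000,
    verbosity=1,
    pbar_keys=["abs_loss", "pv_gap"],  # must correspond to recorded metrics
)
```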
```python
def solve_pv(self, weighted):
```

Estimates the policy value $\rho^\pi$ from the approximated stationary distribution correction $\hat w_{\pi / D}$.

Args:

- `weighted` (bool): Whether to normalize by the sum of weights $\sum_{i=1}^n \hat w_{\pi / D}(s_i, a_i)$ instead of the number of samples $n$.

Returns:

- `pv` (float): Estimated policy value $\hat \rho^\pi$.
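With dataset rewards $r_i$, this corresponds to the unnormalized and self-normalized weighted estimates (a sketch of the definition, not a quote of the implementation):

$$
\hat \rho^\pi =
\begin{cases}
\dfrac{1}{n} \sum_{i=1}^n \hat w_{\pi / D}(s_i, a_i) \, r_i & \text{(weighted = False)}, \\
\dfrac{\sum_{i=1}^n \hat w_{\pi / D}(s_i, a_i) \, r_i}{\sum_{i=1}^n \hat w_{\pi / D}(s_i, a_i)} & \text{(weighted = True)}.
\end{cases}
$$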
## Loss

```python
@abstractmethod
def get_loss(self, v_init, v, v_next, w):
```

Should compute and return the training loss $J$ for SGDA optimization.

Args:

- `v_init` (tf.Tensor): Initial primal value $v(s_0, a_0)$.
- `v` (tf.Tensor): Current primal value $v(s, a)$.
- `v_next` (tf.Tensor): Next primal value $v(s', a')$.
- `w` (tf.Tensor): Current dual value $w(s, a)$.

Returns:

- `loss` (tf.Tensor): Loss scalar $J(v, w)$.
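As an illustration, a hypothetical child class could implement the abstract members with a DualDICE-style saddle-point objective using $f(x) = x^2/2$. This is only a sketch: it assumes the constructor stores `gamma` as `self.gamma`, and the estimators shipped with the repository may use different losses and output activations.

```python
import tensorflow as tf

# NeuralDice is imported from the library; the exact module path is omitted here.

class MyNeuralDice(NeuralDice):

    @property
    def __name__(self):
        return "MyNeuralDice"

    @property
    def output_activation_fn_dual(self):
        # Keep the base-class default; tf.math.square would instead enforce w >= 0.
        return tf.identity

    def get_loss(self, v_init, v, v_next, w):
        # DualDICE-style Lagrangian with f(x) = x^2 / 2:
        #   J(v, w) = E[ w * (v - gamma * v') - w^2 / 2 ] - (1 - gamma) * E[ v_0 ]
        bellman_residual = v - self.gamma * v_next
        loss = (
            tf.reduce_mean(w * bellman_residual - 0.5 * tf.square(w))
            - (1.0 - self.gamma) * tf.reduce_mean(v_init)
        )
        return loss
```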
```python
@tf.function(jit_compile=True)
def get_gradients(
        self,
        obs_init, obs, act, obs_next,
        probs_init, probs_next,
        batch_size):
```

Computes gradients of the loss with respect to all trainable variables using SGDA.

Args:

- `obs_init` (tf.Tensor): Initial observations (states) $s_0$.
- `obs` (tf.Tensor): Current observations (states) $s$.
- `act` (tf.Tensor): Actions $a$.
- `obs_next` (tf.Tensor): Next observations (states) $s'$.
- `probs_init` (tf.Tensor): Array of initial action probabilities under the target policy $\pi(a_0 \mid s_0)$.
- `probs_next` (tf.Tensor): Array of next action probabilities under the target policy $\pi(a' \mid s')$.
- `batch_size` (int): Batch size.

Returns:

- `result` (tuple): `(values, loss, gradients)`.
## Utility

```python
def preprocess_obs(self, obs):
```

Applies preprocessing to raw observations (states) and converts them to a TensorFlow tensor.

Args:

- `obs` (float or tensor): Raw observation (state) data.

Returns:

- `obs_preprocessed` (tf.Tensor): Preprocessed observation (state) tensor.
```python
def preprocess_act(self, act):
```

Applies preprocessing to raw actions and converts them to a TensorFlow tensor.

Args:

- `act` (int or tensor): Raw action data.

Returns:

- `act_preprocessed` (tf.Tensor): Preprocessed action tensor.
```python
def preprocess_rew(self, rew):
```

Applies preprocessing to raw rewards and converts them to a TensorFlow tensor.

Args:

- `rew` (float or tensor): Raw reward data.

Returns:

- `rew_preprocessed` (tf.Tensor): Preprocessed reward tensor.
```python
def preprocess_probs(self, probs):
```

Applies preprocessing to raw action probabilities and converts them to a stacked TensorFlow float tensor.

Args:

- `probs` (float or tensor): Raw action probabilities.

Returns:

- `probs_preprocessed` (tf.Tensor): Preprocessed action probabilities as a tensor.
```python
def get_value(self, network, obs, act):
```

Returns the state-action value for a specific state-action pair from a value network.

Args:

- `network` (ValueNetwork): Network to query, $v$ or $w$.
- `obs` (tf.Tensor): Preprocessed observation (state) tensor $s$.
- `act` (tf.Tensor): Preprocessed action tensor $a$.

Returns:

- `value` (tf.Tensor): Output from the network, $v(s, a)$ or $w(s, a)$.
```python
def get_average_value(self, network, obs, probs):
```

Computes the expected state-action value for an observation (state) from a value network.

Args:

- `network` (ValueNetwork): Value network to query, $v$ or $w$.
- `obs` (tf.Tensor): Observations (states) $s$.
- `probs` (tf.Tensor): Array of action probabilities under the target policy $\pi(a \mid s)$.

Returns:

- `average_value` (tf.Tensor): Expected output from the network, $E_{a \sim \pi(s)}[v(s, a)]$ or $E_{a \sim \pi(s)}[w(s, a)]$.
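Since the action space is discrete with $|A|$ = `n_act` actions, this expectation reduces to a probability-weighted sum over actions:

$$
E_{a \sim \pi(\cdot \mid s)}[v(s, a)] = \sum_{a \in A} \pi(a \mid s) \, v(s, a)
$$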
```python
def get_values(
        self,
        obs_init, obs, act, obs_next,
        probs_init, probs_next):
```

Queries the relevant networks to return all value components needed for loss computation.

Args:

- `obs_init` (tf.Tensor): Initial observations (states) $s_0$.
- `obs` (tf.Tensor): Current observations (states) $s$.
- `act` (tf.Tensor): Actions $a$.
- `obs_next` (tf.Tensor): Next observations (states) $s'$.
- `probs_init` (tf.Tensor): Array of initial action probabilities under the target policy $\pi(a_0 \mid s_0)$.
- `probs_next` (tf.Tensor): Array of next action probabilities under the target policy $\pi(a' \mid s')$.

Returns:

- `values` (tuple): Network outputs `(v_init, v, v_next, w)`.
```python
def evaluate_step(
        self,
        obs_init, obs, act, obs_next,
        probs_init, probs_next, batch_size):
```

Performs a single optimization step using a batch from the dataset.

Args:

- `obs_init` (tf.Tensor): Initial observations (states) $s_0$.
- `obs` (tf.Tensor): Current observations (states) $s$.
- `act` (tf.Tensor): Actions $a$.
- `obs_next` (tf.Tensor): Next observations (states) $s'$.
- `probs_init` (tf.Tensor): Array of initial action probabilities under the target policy $\pi(a_0 \mid s_0)$.
- `probs_next` (tf.Tensor): Array of next action probabilities under the target policy $\pi(a' \mid s')$.
- `batch_size` (int): Batch size.

Returns:

- `result` (tuple): `(values, loss, gradients)`.
```python
def get_sample(self):
```

Samples a training batch from the dataset and applies preprocessing.

Returns:

- `sample` (tuple): `(obs_init, obs, act, obs_next, probs_init, probs_next)`.
## Example

```python
from some_module import SomeNeuralDiceChildClass

estimator = SomeNeuralDiceChildClass(
    gamma=0.99,
    seed=0,
    batch_size=64,
    learning_rate=1e-3,
    hidden_dimensions=(64, 64),
    obs_min=obs_min,
    obs_max=obs_max,
    n_act=4,
    obs_shape=(8,),
    dataset=df,
    dir="./logs",
)

estimator.evaluate_loop(n_steps=10_000)
rho_hat = estimator.solve_pv(weighted=True)
```