# Reinforcement Learning with Stable Baselines3

Build agents that learn by trial and error in simulated environments using Stable Baselines3.
Overview & Use Cases
Reinforcement Learning (RL) enables agents to learn optimal behaviors by interacting with environments. Stable Baselines3 (SB3) provides reliable RL algorithms for research and production.
Use Cases:
- Game AI (Atari, chess, Go)
- Robotics (navigation, manipulation)
- Finance (trading bots)
- Industrial automation (resource allocation)
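All of these use cases share the same agent-environment loop described above. As a minimal sketch using the standard Gymnasium API, with a random policy standing in for a learned one:

```python
import gymnasium as gym

# Create an environment and run one episode with a random policy.
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)

done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()  # a trained agent would choose this instead
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    done = terminated or truncated

print(f"Episode return: {total_reward}")
env.close()
```

SB3 algorithms automate this loop during training; `model.learn()` runs it internally while updating the policy.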
## Prerequisites

- Python: 3.8+
- Stable Baselines3: 2.0+ (the 2.x series targets Gymnasium; 1.x targets the legacy OpenAI Gym)
- Gymnasium: 0.28+
- PyTorch: 1.13+
Install Dependencies:

```bash
# Core packages
pip install "stable-baselines3[extra]" gymnasium torch

# Optional dependencies for specific environments
pip install "gymnasium[box2d]" "gymnasium[atari]" mujoco

# For monitoring and visualization
pip install tensorboard wandb
```
Tips:
- Use virtual environments to manage dependencies
- Check GPU support with `torch.cuda.is_available()`
- Consider installing `sb3-contrib` for additional algorithms (see the example below)
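For instance, `sb3-contrib` ships algorithms such as QRDQN, TQC, and RecurrentPPO that follow the same API as the core library. A minimal sketch, assuming `sb3-contrib` is installed:

```python
import gymnasium as gym
from sb3_contrib import QRDQN  # provided by the sb3-contrib package

# QRDQN uses the same interface as the built-in SB3 algorithms.
env = gym.make("CartPole-v1")
model = QRDQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```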
Architecture & Flow Diagram
Agent
│
▼
Environment (Gym)
│
▼
Agent (Observation + Reward)
│
▼
SB3 Algorithm (Policy Update)
│
▼
Agent (New Policy)
Setup & Configuration
Folder Structure:
rl-demo/
├── train.py
└── requirements.txt
requirements.txt
stable-baselines3[extra]
gymnasium
torch
## Code Example: Training PPO on CartPole

A complete example with callbacks, evaluation, and a custom network architecture.
```python
import gymnasium as gym
import numpy as np
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


# Custom feature extractor (fully connected, suited to CartPole's vector observations)
class CustomMLP(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=128):
        super().__init__(observation_space, features_dim)
        n_input = observation_space.shape[0]
        self.net = nn.Sequential(
            nn.Linear(n_input, 64),
            nn.ReLU(),
            nn.Linear(64, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        return self.net(observations)


# Create vectorized environments for training and evaluation
env = make_vec_env("CartPole-v1", n_envs=4)
eval_env = make_vec_env("CartPole-v1", n_envs=2)

# Periodically evaluate the agent and keep the best checkpoint
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./logs/best_model",
    log_path="./logs/results",
    eval_freq=1000,  # measured in calls per environment step
    deterministic=True,
    render=False,
)

# Plug the custom feature extractor into the policy
policy_kwargs = dict(
    features_extractor_class=CustomMLP,
    features_extractor_kwargs=dict(features_dim=128),
)

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    policy_kwargs=policy_kwargs,
    tensorboard_log="./logs/tensorboard/",
    verbose=1,
)

try:
    # Train the agent
    model.learn(
        total_timesteps=50_000,
        callback=eval_callback,
        progress_bar=True,
    )

    # Save the final model
    model.save("ppo_cartpole_final")

    # Run the trained policy (the VecEnv auto-resets finished episodes)
    obs = env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = env.step(action)
        # env.render() requires environments created with render_mode="human"
except Exception as e:
    print(f"Training failed: {e}")
finally:
    env.close()
    eval_env.close()
```
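After training, the saved model can be reloaded and scored with SB3's built-in evaluation helper. A short sketch, assuming the files produced by the example above:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Reload the saved policy and evaluate it over several episodes
eval_env = make_vec_env("CartPole-v1", n_envs=1)
model = PPO.load("ppo_cartpole_final", env=eval_env)

mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"Mean reward: {mean_reward:.1f} +/- {std_reward:.1f}")
eval_env.close()
```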
## Advanced Example: SAC with Custom Environments
```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.noise import NormalActionNoise


# Custom environment skeleton (e.g., robotics, trading) following the Gymnasium API
class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)

    def step(self, action):
        # Implement the environment dynamics here; placeholder values shown
        obs = self.observation_space.sample()
        reward, terminated, truncated, info = 0.0, False, False, {}
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Reset the environment state here; placeholder observation shown
        return self.observation_space.sample(), {}


# Initialize environment and agent
env = CustomEnv()
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions),
)

model = SAC(
    "MlpPolicy",
    env,
    action_noise=action_noise,
    verbose=1,
    tensorboard_log="./sac_logs/",
)
```
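Before training on a custom environment, it is worth validating it against the Gym interface with SB3's environment checker. A short sketch, using the `CustomEnv` defined above:

```python
from stable_baselines3.common.env_checker import check_env

# Raises descriptive errors (and warnings) if the environment violates the Gym API
env = CustomEnv()
check_env(env, warn=True)
```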
## Common Pitfalls & Debugging Tips

| Issue | Solution |
|---|---|
| Env not found | Install the correct Gym/Gymnasium version and check the env ID |
| Poor performance | Tune hyperparameters, increase training steps |
| CUDA errors | Ensure PyTorch and CUDA versions match |
| Memory issues | Reduce batch size or number of environments |
| Unstable training | Adjust learning rate and clip range |
| Exploration problems | Tune action noise and entropy coefficient |
| Non-converging policy | Check reward function and observation space |
| Slow training | Use vectorized environments and GPU acceleration |
Hyperparameter Optimization Example:

```python
from optuna import Trial, create_study


def objective(trial: Trial) -> float:
    # train_agent is a user-defined helper that trains a model with the sampled
    # hyperparameters and returns the mean evaluation reward (sketched below).
    return train_agent(
        learning_rate=trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        batch_size=trial.suggest_int("batch_size", 32, 256),
        n_steps=trial.suggest_int("n_steps", 1024, 4096),
    )


study = create_study(direction="maximize")
study.optimize(objective, n_trials=50, show_progress_bar=True)
```
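The `train_agent` helper referenced above is user-defined, not part of SB3 or Optuna. A minimal sketch, assuming short PPO runs on CartPole scored with SB3's `evaluate_policy`:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy


def train_agent(learning_rate: float, batch_size: int, n_steps: int) -> float:
    """Train a short PPO run with the sampled hyperparameters and return mean reward."""
    env = gym.make("CartPole-v1")
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=learning_rate,
        batch_size=batch_size,
        n_steps=n_steps,
        verbose=0,
    )
    model.learn(total_timesteps=20_000)
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
    env.close()
    return mean_reward
```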
Scaling & Optimization
Advanced techniques for scaling RL training and deployment.
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor
from stable_baselines3.common.utils import set_random_seed
def make_env(env_id, rank, seed=0):
def _init():
env = gym.make(env_id)
env.seed(seed + rank)
return env
set_random_seed(seed)
return _init
# Parallel environments
num_envs = 8
env = SubprocVecEnv([make_env("CartPole-v1", i) for i in range(num_envs)])
env = VecMonitor(env, "logs/monitor")
# Enable GPU training
device = "cuda" if th.cuda.is_available() else "cpu"
model = PPO("MlpPolicy", env, device=device)
Tips:
- Use vectorized environments for parallel training
- Enable GPU acceleration for faster learning
- Tune hyperparameters systematically
- Save and resume training from checkpoints (see the sketch after this list)
- Monitor with TensorBoard or Weights & Biases
- Use frame stacking for visual observations
- Implement reward shaping for complex tasks
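For the checkpointing tip above, SB3 provides `CheckpointCallback`, and training can be resumed by loading a saved model. A brief sketch, assuming PPO on CartPole:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback

env = gym.make("CartPole-v1")

# Save a checkpoint every 10,000 steps
checkpoint_callback = CheckpointCallback(
    save_freq=10_000,
    save_path="./checkpoints/",
    name_prefix="ppo_cartpole",
)

model = PPO("MlpPolicy", env, verbose=0)
model.learn(total_timesteps=50_000, callback=checkpoint_callback)

# Later: reload a checkpoint and continue training without resetting the step counter.
# Checkpoint files follow "<name_prefix>_<steps>_steps.zip"; adjust the path to a real file.
model = PPO.load("./checkpoints/ppo_cartpole_50000_steps", env=env)
model.learn(total_timesteps=50_000, reset_num_timesteps=False)
```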
Security & Privacy Notes
Best practices for secure RL deployment.
Aspect | Consideration |
---|---|
Model Storage | Encrypt saved policies and parameters |
Environment Data | Sanitize sensitive information in observations |
API Access | Implement authentication for deployed agents |
Safety Constraints | Add action bounds and safety checks |
Monitoring | Log agent behavior and anomalies |
Testing | Validate policies in sandboxed environments |
Updates | Version control for models and environments |
Hardware | Secure physical access for robotics applications |
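For the safety-constraints row above, one common pattern is an action wrapper that clamps out-of-bounds actions before they reach the real system. A minimal sketch using the Gymnasium wrapper API, with an illustrative bound:

```python
import gymnasium as gym
import numpy as np


class SafeActionWrapper(gym.ActionWrapper):
    """Clip every action into a conservative sub-range of the action space."""

    def __init__(self, env, max_magnitude=0.5):  # max_magnitude is a hypothetical bound
        super().__init__(env)
        self.max_magnitude = max_magnitude

    def action(self, action):
        # Clamp the agent's action before it is applied to the environment
        return np.clip(action, -self.max_magnitude, self.max_magnitude)


# Usage: wrap any continuous-control environment before training or deployment
env = SafeActionWrapper(gym.make("Pendulum-v1"), max_magnitude=0.5)
```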
## Further Reading & Resources
- Stable Baselines3 Docs
- RL Baselines3 Zoo
- OpenAI Spinning Up in Deep RL
- RL Course by David Silver
- SB3 Contrib — Additional algorithms
- Gymnasium Docs — Environment creation
- CleanRL — Single-file implementations
- RL Papers with Code