Recall the difference between "off-policy" versus "on-policy" reinforcement learning. Off-policy methods like DQN or SAC reuse data from past policies (stored in a replay buffer) and can take multiple gradient steps on that data; taken to the extreme, fully offline RL learns from a fixed dataset without collecting any new experience at all.
On-policy methods like vanilla policy gradient or PPO, by contrast, gather fresh rollouts from the current policy at every update, which makes them simpler conceptually but less data-efficient. This trade-off—data reuse versus fresh experience—motivates turning to model-based RL, where we learn a model of the environment’s dynamics so we can generate extra data ourselves.
So what’s the big idea of model-based RL? We treat our model as a learned “simulator” that, given a state and action, predicts the next state (and often the reward).
In domains where physics is well understood (some games, rigid-body systems) you might know the true dynamics equations and only need to fit a few unknown parameters from data; in most real-world scenarios you have no analytic model, so you learn it end-to-end—often first compressing observations into a low-dimensional state representation, then training a neural network to predict state transitions (and sometimes a separate network to predict rewards).
Forward Dynamics Model
The first task in MBRL is to learn a forward dynamics model $\hat{f}_\theta(s_t, a_t) \approx s_{t+1}$ using a neural net with parameters $\theta$.
Optionally, you can also learn a reward model $\hat{r}_\psi(s_t, a_t) \approx r_t$, since the reward is also part of the data returned to us by the world.
Suppose we have some dataset of transitions $\mathcal{D} = \{(s_t, a_t, s_{t+1}, r_t)\}$ (let's say collected by a base policy $\pi_0$ for now). Then we train by minimizing the supervised losses

$$\mathcal{L}(\theta) = \sum_{(s_t, a_t, s_{t+1}) \in \mathcal{D}} \left\lVert \hat{f}_\theta(s_t, a_t) - s_{t+1} \right\rVert^2$$

and

$$\mathcal{L}(\psi) = \sum_{(s_t, a_t, r_t) \in \mathcal{D}} \left( \hat{r}_\psi(s_t, a_t) - r_t \right)^2.$$
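To make this concrete, here is a minimal sketch of the supervised training step, assuming continuous states and actions and a PyTorch setup. The network sizes, names, and hyperparameters are illustrative, not from the post.

```python
import torch
import torch.nn as nn

state_dim, action_dim = 8, 2   # illustrative sizes, not from the post

# Simple MLPs standing in for the dynamics model f_theta and reward model r_psi.
f_theta = nn.Sequential(nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
                        nn.Linear(256, state_dim))
r_psi = nn.Sequential(nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
                      nn.Linear(256, 1))
opt = torch.optim.Adam(list(f_theta.parameters()) + list(r_psi.parameters()), lr=1e-3)

def train_step(s, a, s_next, r):
    """One supervised step on a batch of transitions (s, a, s', r) drawn from D."""
    sa = torch.cat([s, a], dim=-1)
    loss_dyn = ((f_theta(sa) - s_next) ** 2).sum(-1).mean()   # ||f_theta(s,a) - s'||^2
    loss_rew = ((r_psi(sa).squeeze(-1) - r) ** 2).mean()      # (r_psi(s,a) - r)^2
    loss = loss_dyn + loss_rew
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```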
Once we have $\hat{f}_\theta$ (and $\hat{r}_\psi$), we want to plan some sequence of actions that maximizes reward over the trajectory. We fix some finite horizon $H$, which is essentially the number of actions we have to carry out to complete a task, and plan over it by solving

$$a_{1:H}^{*} = \arg\max_{a_{1:H}} \sum_{t=1}^{H} \hat{r}_\psi(\hat{s}_t, a_t),$$

where $\hat{s}_{t+1} = \hat{f}_\theta(\hat{s}_t, a_t)$ and $\hat{s}_1 = s_1$.
Then we can do the following gradient-based optimization algorithm.
- Run some base policy $\pi_0(a_t \mid s_t)$ (e.g., a random policy) to collect $\mathcal{D} = \{(s_t, a_t, s_{t+1}, r_t)\}$
- Learn $\hat{f}_\theta(s_t, a_t)$ (and $\hat{r}_\psi(s_t, a_t)$) according to some loss (like the MSE above)
- Differentiate through $\hat{f}_\theta$ and $\hat{r}_\psi$ to do gradient ascent on the discounted predicted future reward (a minimal sketch follows this list)
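Here is a rough sketch of that last step, reusing the illustrative `f_theta`, `r_psi`, and `action_dim` from the earlier snippet. The horizon, learning rate, and iteration count are placeholder values.

```python
def plan_with_gradients(s1, H=15, iters=100, lr=0.05, gamma=0.99):
    """Gradient-based planning: unroll the learned model for H steps and do
    gradient ascent on the (discounted) predicted return w.r.t. the actions."""
    a_seq = torch.zeros(H, action_dim, requires_grad=True)
    plan_opt = torch.optim.Adam([a_seq], lr=lr)
    for _ in range(iters):
        s, ret = s1, 0.0
        for t in range(H):
            sa = torch.cat([s, a_seq[t]], dim=-1)
            ret = ret + (gamma ** t) * r_psi(sa).squeeze(-1)   # predicted reward
            s = f_theta(sa)                                    # predicted next state
        plan_opt.zero_grad()
        (-ret).backward()   # ascend on return = descend on -return
        plan_opt.step()
    return a_seq.detach()
```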
Here are some things we can observe about this algorithm.
- Pro: Scalable to high dimensions
- Pro: Works well especially in overparameterized regimes
- Con: Requires a nice optimization landscape
Alternatively, we can take a sampling-based approach called the cross-entropy method or CEM (not to be confused with cross-entropy loss).
- Initialize a Gaussian over action sequences, $a_{1:H} \sim \mathcal{N}(\mu, \Sigma)$
- Sample $N$ sequences $a_{1:H}^{(1)}, \dots, a_{1:H}^{(N)}$ and evaluate the predicted return of each under the model
- Pick the top-$k$ "elite" sequences
- Refit $\mu, \Sigma$ to the elites
- Repeat for a few iterations.
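A sketch of CEM planning under these steps, with the learned models wrapped as plain callables `f(s, a) -> s_next` and `r(s, a) -> reward`; the population size, elite count, and iteration count are placeholder values, and `action_dim` is reused from the earlier snippet.

```python
import numpy as np

def predicted_return(s1, a_seq, f, r):
    """Predicted return of an action sequence under the learned model."""
    s, total = s1, 0.0
    for a in a_seq:
        total += r(s, a)
        s = f(s, a)
    return total

def plan_with_cem(s1, f, r, H=15, pop=500, k=50, iters=5):
    """Cross-entropy method: sample action sequences from a Gaussian, keep the
    top-k "elites" by predicted return, refit the Gaussian, and repeat."""
    mu, sigma = np.zeros((H, action_dim)), np.ones((H, action_dim))
    for _ in range(iters):
        samples = mu + sigma * np.random.randn(pop, H, action_dim)
        returns = np.array([predicted_return(s1, a, f, r) for a in samples])
        elites = samples[np.argsort(returns)[-k:]]        # top-k elite sequences
        mu, sigma = elites.mean(axis=0), elites.std(axis=0)
    return mu   # plan = mean of the final elite distribution
```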
The main advantages and disadvantages then are the following.
- Pro: Highly parallelizable
- Pro: Requires no gradient info
- Con: Scales poorly to higher dimensions
Let's consider the following thought experiment. Suppose our task is to move to the right, climbing a cliff to get as high as possible, but just past the top of the cliff there is a ledge with a sharp drop.
Our current approach collects data kind of randomly, which means the observed data likely doesn't reflect the existence of the ledge. Thus the forward dynamics model trained on this initially collected data, and our method of sampling/backprop to choose optimal actions, will learn that "going right means going higher".
But then our agent overshoots and launches off the cliff! What went wrong? The crux is that the state distribution collected by $\pi_0$ doesn't match that of $\pi_f$, the policy we learn by following $\hat{f}_\theta$.
Fix: after executing our planned horizon of actions $a_{1:H}$, append the newly visited transitions to $\mathcal{D}$ before re-fitting the model and re-planning.
Open-Loop vs Closed-Loop
The above example brings us to the distinction between open-loop and closed-loop control. In the open-loop setting, our planned action sequence doesn't take into account the actual state dynamics of the world: at inference time the agent relies solely on the pretrained forward dynamics model to choose its actions and never looks at the true next states it encounters.
In the closed-loop setting, the agent incorporates live state transitions as it executes its planned actions. So if our dynamics model starts to deviate from the true world, the agent can notice and re-plan to correct for the model's errors.
This brings us to Model-Predictive Control (MPC).
- Run a base policy $\pi_0(a_t \mid s_t)$ (e.g., a random policy) to collect $\mathcal{D} = \{(s_t, a_t, s_{t+1}, r_t)\}$
- Learn $\hat{f}_\theta$ and $\hat{r}_\psi$ to minimize the supervised losses above
- Use $\hat{f}_\theta$ and $\hat{r}_\psi$ to optimize the action sequence $a_{1:H}$ (via gradients or CEM)
- Execute only the first planned action $a_1$, observe the next state $s'$ and reward $r$, then repeat the previous (planning) step from $s'$ for a full horizon
- Append the visited transitions to $\mathcal{D}$ and go back to Step 2 (a sketch of the full loop follows this list)
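A hedged sketch of the MPC loop; the env interface and the `planner` callable are stand-ins (not a real API), and the model re-fitting in Step 2 is left as a comment.

```python
def run_mpc(env, planner, D, H=15, episodes=10):
    """Model-predictive control loop: plan H steps ahead, execute only the first
    action, observe the true next state, and re-plan from there. Assumes a
    simplified env interface with reset() -> s and step(a) -> (s_next, r, done)."""
    for _ in range(episodes):
        s = env.reset()
        done = False
        while not done:
            a_seq = planner(s, H=H)          # e.g. plan_with_gradients or plan_with_cem
            a = a_seq[0]                     # execute only the first planned action
            s_next, r, done = env.step(a)
            D.append((s, a, s_next, r))      # grow D so the model can be re-fit later
            s = s_next
        # After each batch of episodes: re-train f_theta / r_psi on D (back to Step 2).
    return D
```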
Now we can re-plan to correct for model errors, but the downside is that this approach can be compute intensive for some applications. The other constraint is that our horizons typically need to be short, since longer horizons become increasingly compute intensive and lead to higher-variance results.
Model-Based Policy Optimization
So far, our planner isn't really a policy for our agent in the traditional sense. Can we construct a policy using our planner and learned model? The first option is to distill the planner's actions into a policy: that is, train a policy $\pi_\phi(a_t \mid s_t)$ to match the actions taken by our planner.
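As a sketch, distillation can be as simple as behavior cloning on (state, planned action) pairs logged while running MPC; the names and the plain MSE objective here are illustrative, not from the post.

```python
def distill_planner(policy, policy_opt, plan_data):
    """Behavior-clone the planner: regress the policy's action onto the
    planner's chosen action for each visited state."""
    for s, a_star in plan_data:                    # (state, planner action) pairs
        loss = ((policy(s) - a_star) ** 2).mean()  # match the planner's action
        policy_opt.zero_grad()
        loss.backward()
        policy_opt.step()
```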
This solves the issue of being compute intensive at inference time, but we are still constrained to short horizons. How might we solve longer horizon problems? Let's consider the following two ideas.
- Plan with a terminal value function
- Augment model-free RL methods with data from model
We'll focus on the second idea first. What we want to do is augment our real data with model-simulated rollouts. We could generate full trajectories from initial states, but then we run into the issue of model inaccuracy over longer horizons.
Alternatively, we can generate partial trajectories from initial states, but this might not be a good reflection of desired behavior at later states.
Our key idea is that we will generate partial trajectories from all states in the data. Thus we get coverage of partial trajectories close to initial and final states, without having to rollout long horizons.
- Collect data using the current policy $\pi_\phi$, add it to $\mathcal{D}$
- Update the model $\hat{f}_\theta$ (and $\hat{r}_\psi$) using $\mathcal{D}$
- Collect short synthetic rollouts using $\pi_\phi$ in the model from states in $\mathcal{D}$; add them to $\mathcal{D}_{\text{model}}$ (a sketch follows this list)
- Update the policy $\pi_\phi$ (and critic $Q$) using $\mathcal{D}_{\text{model}}$ (and optionally $\mathcal{D}$). Repeat the whole algorithm
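A sketch of the synthetic-rollout step, again treating the learned model and policy as plain callables; the rollout length `k` and the number of branch points are made-up values, and the subsequent policy/critic update would be whatever model-free method you prefer.

```python
import random

def short_model_rollouts(D, f, r, policy, k=5, n_starts=1000):
    """Branch short k-step rollouts from states sampled anywhere in D, using the
    learned model (f, r) and the current policy; returns synthetic transitions."""
    D_model = []
    for _ in range(n_starts):
        s = random.choice(D)[0]          # start from any real state in D
        for _ in range(k):
            a = policy(s)
            s_next, reward = f(s, a), r(s, a)
            D_model.append((s, a, s_next, reward))
            s = s_next
    return D_model
```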
Here's the lowdown on the pros and cons.
- Pro: Immensely useful, far more data efficient if model is easy to learn
- Pro: Model can be trained on data without reward labels (fully self-supervised)
- Pro: Model is somewhat task-agnostic (can sometimes be transferred across rewards)
- Con: Models don’t optimize for task performance
- Con: Sometimes harder to learn than a policy
- Con: Another thing to train, more hyperparameters, more compute intensive
There are also all kinds of alternative model-based approaches. Instead of modeling $s_{t+1}$ given $(s_t, a_t)$, we could do something like
- Inverse modeling, where we model $a_t$ given $(s_t, s_{t+1})$ (see the sketch after this list)
- Future prediction without actions, where we predict $s_{t+1}$ from $s_t$ alone
- Video interpolation, where we model intermediate states $s_t$ from $(s_{t-1}, s_{t+1})$
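For instance, the first of these just swaps the prediction target; a hedged sketch, reusing the imports and the illustrative `state_dim`/`action_dim` from the earlier snippets:

```python
# Illustrative inverse-dynamics model: predict the action a_t that took us
# from s_t to s_{t+1}.
inv_model = nn.Sequential(nn.Linear(2 * state_dim, 256), nn.ReLU(),
                          nn.Linear(256, action_dim))

def inverse_loss(s, s_next, a):
    """MSE between the predicted and actual action for a batch of transitions."""
    pred_a = inv_model(torch.cat([s, s_next], dim=-1))
    return ((pred_a - a) ** 2).mean()
```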
Conclusion
That's all for this post, although this is more of an overview of the landscape of methods within MBRL. There's been some especially interesting research recently within this subject, including the topic of our final project for CS 224R, which is autonomous driving using the [DreamerV3](https://arxiv.org/abs/2301.04104) architecture.
This architecture learns forward dynamics in a latent space: an encoder compresses real-world observations into latent representations, a recurrent sequence model predicts how those latents evolve, and decoders reconstruct observations and rewards. This world model is referred to as a Recurrent State-Space Model (RSSM) and is used to solve a variety of complex tasks across diverse domains.
A 2024 paper, [Think2Drive](https://arxiv.org/abs/2402.16720), applies this to the problem of closed-loop RL-based autonomous driving, an especially fitting application of MBRL: since letting an agent learn to drive in the real world is hardly an option, simulation is virtually a necessity for training driving agents.
I hope to write another post in the near future detailing how our research extends this work, but I think this already demonstrates the utility and diversity of these methods.