Model-Based RL

May 30, 2025

Recall the difference between off-policy and on-policy reinforcement learning. Off-policy methods like DQN or SAC reuse data from past policies (stored in a replay buffer) and can take multiple gradient steps on that data, rather than relying only on experience gathered by the current policy.

On-policy methods like vanilla policy gradient or PPO, by contrast, gather fresh rollouts from the current policy at every update, which makes them simpler conceptually but less data-efficient. This trade-off—data reuse versus fresh experience—motivates turning to model-based RL, where we learn a model of the environment’s dynamics so we can generate extra data ourselves.

So what’s the big idea of model-based RL? We treat our model as a learned “simulator” that, given a state and action, predicts the next state (and often the reward).

In domains where physics is well understood (some games, rigid-body systems) you might know the true dynamics equations and only need to fit a few unknown parameters from data; in most real-world scenarios you have no analytic model, so you learn it end-to-end—often first compressing observations into a low-dimensional state representation, then training a neural network to predict state transitions (and sometimes a separate network to predict rewards).

Forward Dynamics Model

The first task in MBRL is to learn a forward dynamics model using a neural net with parameters $\theta$.

$$\hat{f}_{\theta}(s,a) \approx \mathbb{E}[s' \mid s, a]$$

Optionally, you can also learn a reward model $\hat{r}_{\phi}(s,a) \approx r(s,a)$ since this is also part of the data returned to us by the world.

Suppose we have some dataset $\mathcal{D} = \{(s_{t}, a_t, r_t, s_{t+1})\}$ of transitions (let's say collected by a base policy $\pi_0$ for now). Then we train by minimizing the supervised losses

$$\theta^* = \arg\min_{\theta} \frac{1}{|\mathcal{D}|} \sum_{(s,a,r,s') \in \mathcal{D}} \|s' - \hat{f}_{\theta}(s,a)\|^2$$

and

$$\phi^* = \arg\min_{\phi} \frac{1}{|\mathcal{D}|} \sum_{(s,a,r,s') \in \mathcal{D}} |r - \hat{r}_{\phi}(s,a)|^2$$
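As a rough sketch of what this looks like in practice, here is one way to fit $\hat f_\theta$ and $\hat r_\phi$ with PyTorch by minimizing the two losses above. The network sizes, the `STATE_DIM`/`ACTION_DIM` constants, and the assumption that `dataset` yields mini-batches of $(s, a, r, s')$ tensors are illustrative choices, not requirements.

```python
import torch
import torch.nn as nn

# Illustrative dimensions; in practice these come from the environment.
STATE_DIM, ACTION_DIM = 17, 6

class DynamicsModel(nn.Module):
    """Forward dynamics model: predicts the next state s' from (s, a)."""
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, state_dim),
        )

    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=-1))

class RewardModel(nn.Module):
    """Reward model: predicts the scalar reward r from (s, a)."""
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=-1)).squeeze(-1)

def fit_models(f_hat, r_hat, dataset, epochs=50, lr=1e-3):
    """Minimize the two supervised losses above over mini-batches of (s, a, r, s')."""
    opt = torch.optim.Adam(list(f_hat.parameters()) + list(r_hat.parameters()), lr=lr)
    for _ in range(epochs):
        for s, a, r, s_next in dataset:
            loss = ((f_hat(s, a) - s_next) ** 2).sum(-1).mean() \
                 + ((r_hat(s, a) - r) ** 2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
    return f_hat, r_hat
```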

Once we have $\hat{f}_{\theta}$, we want to plan a sequence of actions that maximizes reward over the trajectory. We fix some finite horizon $H$, essentially the number of actions we get to carry out to complete the task, and plan over it by solving for

$$\mathcal{A}^H := \{a_0,\dots,a_{H-1}\}^* = \arg\max_{a_{0:H-1}} \sum_{t=0}^{H-1}\gamma^t\,\hat r_\phi\bigl(s_t,a_t\bigr)$$

where $s_{t+1}=\hat f_\theta(s_t,a_t)$ and $s_0=s$.

Then we can run the following gradient-based optimization algorithm.

  • Run some policy to collect $\mathcal{D}$
  • Learn $\hat f_\theta$ according to some loss (like the MSE above)
  • Differentiate through $\hat f_\theta$ and $\hat r_\phi$ to do gradient ascent on the discounted predicted future reward.
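Here is a minimal sketch of that last step, reusing the hypothetical models above: treat the $H$-step action sequence as a free parameter and backpropagate the negated predicted return through $\hat f_\theta$ and $\hat r_\phi$.

```python
def plan_by_gradient(f_hat, r_hat, s0, horizon=15, gamma=0.99, steps=100, lr=0.05):
    """Gradient-based planning: optimize an H-step action sequence by
    differentiating the predicted discounted return through the learned models."""
    actions = torch.zeros(horizon, ACTION_DIM, requires_grad=True)
    opt = torch.optim.Adam([actions], lr=lr)
    for _ in range(steps):
        s, ret = s0, 0.0
        for t in range(horizon):
            a = torch.tanh(actions[t])      # keep actions in a bounded range
            ret = ret + (gamma ** t) * r_hat(s, a)
            s = f_hat(s, a)                 # roll the learned model forward
        loss = -ret                         # gradient ascent on predicted return
        opt.zero_grad()
        loss.backward()
        opt.step()
    return torch.tanh(actions).detach()
```

Note that this relies on $\hat f_\theta$ and $\hat r_\phi$ being differentiable and reasonably smooth, which is exactly the optimization-landscape caveat below.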

Here are some things we can observe about this algorithm.

  • Pro: Scalable to high dimensions
  • Pro: Works well especially in overparameterized regimes
  • Con: Requires a nice optimization landscape

Alternatively, we can take a sampling-based approach called the cross-entropy method or CEM (not to be confused with cross-entropy loss).

  • Initialize a Gaussian $\mathcal N(\mu,\Sigma)$ over $H$-step action sequences,
  • Sample $N$ sequences $\{a^{(i)}_{0:H-1}\}$ and evaluate each
$$J^{(i)}=\sum_{t=0}^{H-1}\gamma^t\,\hat r_\phi(s_t^{(i)},a_t^{(i)})$$
  • Pick the top-$K$ “elite” sequences $E$,
  • Refit
$$\mu\leftarrow \frac1K\sum_{i\in E}a^{(i)},\quad \Sigma\leftarrow \frac1K\sum_{i\in E}(a^{(i)}-\mu)(a^{(i)}-\mu)^\top$$
  • Repeat for a few iterations.
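A minimal sketch of CEM planning under the same assumed models; the population size, elite count, and diagonal covariance are illustrative simplifications.

```python
def plan_with_cem(f_hat, r_hat, s0, horizon=15, gamma=0.99,
                  n_samples=500, n_elites=50, n_iters=5):
    """Cross-entropy method: refit a Gaussian over H-step action sequences
    to the top-K sequences ranked by predicted return under the learned model."""
    mu = torch.zeros(horizon, ACTION_DIM)
    sigma = torch.ones(horizon, ACTION_DIM)              # diagonal covariance

    for _ in range(n_iters):
        # Sample N candidate action sequences from the current Gaussian.
        samples = mu + sigma * torch.randn(n_samples, horizon, ACTION_DIM)

        # Evaluate each sequence's discounted predicted return under f_hat, r_hat.
        returns = torch.zeros(n_samples)
        with torch.no_grad():
            s = s0.expand(n_samples, -1)
            for t in range(horizon):
                a = samples[:, t]
                returns += (gamma ** t) * r_hat(s, a)
                s = f_hat(s, a)

        # Keep the top-K elites and refit the Gaussian to them.
        elites = samples[returns.topk(n_elites).indices]
        mu, sigma = elites.mean(dim=0), elites.std(dim=0) + 1e-6

    return mu                                            # planned action sequence
```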

The main advantages and disadvantages then are the following.

  • Pro: Highly parallelizable
  • Pro: Requires no gradient info
  • Con: Scales poorly to higher dimensions

Let's consider the following thought experiment. Suppose our task is to move to the right and climb as high as possible, but past the peak the terrain drops off a cliff.

Our current approach collects data with $\pi_0$ somewhat randomly, which means the observed data likely doesn't reflect the existence of the cliff edge. Thus the forward dynamics model trained on this initial data, together with our sampling/backprop method for choosing optimal actions, will learn that “going right means going higher”.

But then our agent overshoots and launches off the cliff! What went wrong? The crux is that the distribution of states visited by $\pi_0$ doesn't match that of $\pi_f$, the policy we end up following by planning through $\hat f_\theta$.

Fix: after executing our planned actions $\mathcal{A}^H$ in the real environment, append the visited transitions to $\mathcal{D}$ before repeating the loop.

Open-Loop vs Closed-Loop

The above example brings us to the distinction between open-loop and closed-loop control. The difference is that in open loop, our planned action sequence $\mathcal{A}^H$ doesn't take into account the actual state transitions of the world. That is, at inference time the agent relies solely on the pretrained forward dynamics model to choose actions and never consults the true next states.

In closed loop, the agent incorporates the live state transitions it observes while executing actions at inference time. Thus if our dynamics model starts to deviate from the true world, the agent can take this into account and re-plan to correct for model errors.

This brings us to Model-Predictive Control (MPC).

  • Run a base policy $\pi_0$ to collect $\mathcal{D}$
  • Learn $\hat f_\theta$ and $\hat r_\phi$ to minimize $\sum_i \|\hat f_\theta(s_i, a_i) - s'_i\|^2$, ...
  • Use $\hat f_\theta(s,-)$ and $\hat r_\phi(s,-)$ to optimize $\mathcal{A}^H$
  • Execute the first planned action $a$, observe the next state $s'$ and reward $r$. Repeat the previous step for a full horizon $H$
  • Append the transitions to $\mathcal{D}$ and go back to step 2
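A sketch of one MPC episode, reusing the hypothetical `plan_with_cem` planner from above and assuming a classic Gym-style `reset`/`step` interface for `env` (both are stand-ins, not fixed requirements).

```python
def mpc_episode(env, f_hat, r_hat, dataset, horizon=15):
    """Closed-loop MPC: re-plan from the observed state at every step,
    execute only the first planned action, and log the real transition."""
    s = torch.as_tensor(env.reset(), dtype=torch.float32)   # assumed classic Gym API
    done = False
    while not done:
        plan = plan_with_cem(f_hat, r_hat, s, horizon=horizon)
        a = plan[0]                                          # execute only the first action
        s_next, r, done, _ = env.step(a.numpy())             # assumed Gym-style step
        s_next = torch.as_tensor(s_next, dtype=torch.float32)
        dataset.append((s, a, torch.tensor(r, dtype=torch.float32), s_next))
        s = s_next                                           # re-plan from the true next state
    return dataset

# Outer loop: alternate between collecting with MPC and retraining the models.
# `make_batches` is a hypothetical helper that turns the transition list into mini-batches.
# for _ in range(n_outer_iters):
#     dataset = mpc_episode(env, f_hat, r_hat, dataset)
#     fit_models(f_hat, r_hat, make_batches(dataset))
```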

Now we can re-plan around model errors, but the downside is that this approach can be compute intensive at decision time. The other constraint is that our horizons $H$ typically need to be short, since longer horizons are increasingly expensive to plan over and lead to higher-variance results.

Model-Based Policy Optimization

So far, our planner isn't really a policy for our agent in the traditional sense. Can we construct a policy $\pi$ using our planner and model? The first option is to distill the planner's actions into a policy. That is, train a policy to match the actions taken by our planner.
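A minimal sketch of that distillation step, assuming a simple deterministic policy network and a dataset of transitions whose actions came from the planner; the architecture and squared-error loss are illustrative.

```python
class PolicyNet(nn.Module):
    """Deterministic policy pi(s) -> a, distilled from planner behavior."""
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim), nn.Tanh(),
        )

    def forward(self, s):
        return self.net(s)

def distill_planner(policy, dataset, epochs=20, lr=1e-3):
    """Behavior cloning: regress the policy's action onto the action the
    planner actually executed in each logged state."""
    opt = torch.optim.Adam(policy.parameters(), lr=lr)
    for _ in range(epochs):
        for s, a_planner, _, _ in dataset:       # (s, a, r, s') transitions from MPC
            loss = ((policy(s) - a_planner) ** 2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
    return policy
```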

This solves the issue of being compute intensive at inference time, but we are still constrained to short horizons. How might we solve longer horizon problems? Let's consider the following two ideas.

  1. Plan with a terminal value function
  2. Augment model-free RL methods with data from model

We'll focus on (2) first. What we want to do is augment data with model-simulated rollouts. We can potentially generate full trajectories from initial states, but we run into the issue of model inaccuracy for longer horizons.

Alternatively, we can generate partial trajectories from initial states, but this might not be a good reflection of desired behavior at later states.

Our key idea is to generate partial trajectories branching from all states in the data. Thus we get coverage of partial trajectories close to both initial and final states, without having to roll out long horizons.

  • Collect data in the real environment using the current policy $\pi$; add it to $\mathcal{D}_{env}$
  • Update $\hat f_\theta(s, a)$ using $\mathcal{D}_{env}$
  • Collect synthetic rollouts using $\pi$ in the model $\hat f_\theta$, starting from states in $\mathcal{D}_{env}$; add them to $\mathcal{D}_{model}$
  • Update the policy $\pi$ (and critic $Q$) using $\mathcal{D}_{model}$. Repeat the whole algorithm
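A rough sketch of the model-rollout step in this loop, with `policy.act` and `agent.update` as hypothetical stand-ins for an off-policy learner (e.g. a SAC-style actor-critic); the branch count and rollout length `k` are illustrative.

```python
import random

def generate_model_rollouts(f_hat, r_hat, policy, d_env, k=5, n_branches=400):
    """Branch short k-step rollouts in the learned model from real states in d_env."""
    d_model = []
    with torch.no_grad():
        for _ in range(n_branches):
            s = random.choice(d_env)[0]          # start from a randomly chosen real state
            for _ in range(k):
                a = policy.act(s)                # hypothetical: sample an action from pi
                r = r_hat(s, a)
                s_next = f_hat(s, a)
                d_model.append((s, a, r, s_next))
                s = s_next
    return d_model

# One outer iteration of the algorithm above (all helpers are stand-ins):
# d_env += collect_real_transitions(env, policy)
# fit_models(f_hat, r_hat, make_batches(d_env))
# d_model = generate_model_rollouts(f_hat, r_hat, policy, d_env)
# for batch in make_batches(d_model):
#     agent.update(batch)                        # off-policy actor-critic update
```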

Here's the lowdown on the pros and cons.

  • Pro: Immensely useful, far more data efficient if model is easy to learn
  • Pro: Model can be trained on data without reward labels (fully self-supervised)
  • Pro: Model is somewhat task-agnostic (can sometimes be transferred across rewards)
  • Con: Models don’t optimize for task performance
  • Con: Sometimes harder to learn than a policy
  • Con: Another thing to train, more hyperparameters, more compute intensive

There are also all kinds of alternative model-based approaches. Instead of modeling $s'$ given $s,a$, we could do something like

  • Inverse modeling, where we model $a$ given $s,s'$
  • Future prediction without actions, predicting $s'$ from $s$
  • Video interpolation, modeling $s_{t+1:t+n}$ from $s_t, s_{t+n}$
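As one example, here is a minimal sketch of the first alternative, an inverse dynamics model, under the same assumed dimensions:

```python
class InverseDynamicsModel(nn.Module):
    """Inverse dynamics: predict the action a that maps s to s'."""
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, s, s_next):
        return self.net(torch.cat([s, s_next], dim=-1))

# Trained like the forward model, but regressing onto the action:
# loss = ((inv_model(s, s_next) - a) ** 2).mean()
```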

Conclusion

That's all for this post, although this is more of an overview of the landscape of methods within MBRL. There's been some especially interesting research in this area recently, including the topic of our final project for CS 224R, which is autonomous driving using the [DreamerV3](https://arxiv.org/abs/2301.04104) architecture.

This architecture decomposes forward dynamics into several different predictors that encode real-world observations into compact latent representations and propagate them with a recurrent sequence model alongside an encoder-decoder pair. The resulting world model is referred to as a Recurrent State-Space Model (RSSM) and is used to solve a variety of complex tasks in diverse domains.

A 2024 paper, [Think2Drive](https://arxiv.org/abs/2402.16720), applies this to the problem of closed-loop RL-based autonomous driving, which is an especially useful application of MBRL: simulation is virtually a necessity for training driving agents, since the alternative is letting one learn to drive in the real world.

I hope to write another post in the near future detailing how our research extends this work, but I think this overview demonstrates the diversity and utility of these methods.