Recall the notion of importance weights from the previous post. One drawback I didn't discuss is that this approach is only suitable when the policies are very similar to one another. If the policies are very different, the importance weights can blow up or vanish, giving us a worse estimate of the gradient of the loss function.
There is also the issue, demonstrated in the example from the last post, that policy gradients don't make efficient use of data! If a robot rewarded for forward velocity takes one small step forward and then falls backward, the update pushes down the likelihood of the whole trajectory, including the step forward.
With sparse rewards, actions with partially correct results aren't utilized by policy gradients at all, since no actual reward is received.
In this post we'll discuss an alternative class of algorithms called Actor Critic Methods.
Value Functions
Let's revisit some useful functions related to Markov Decision Processes (MDPs).
The (on-policy) value function $V^\pi(s)$ - future expected rewards starting at $s$ and following $\pi$.
$$V^\pi(s) = \mathbb{E}_{\tau \sim p_\theta(\tau)}\left[ r(\tau) \mid s_1 = s \right]$$
The (on-policy) action-value function $Q^\pi(s, a)$ - future expected rewards starting at $s$, taking $a$, then following $\pi$.
$$Q^\pi(s, a) = \mathbb{E}_{\tau \sim p_\theta(\tau)}\left[ r(\tau) \mid s_1 = s, a_1 = a \right]$$
One useful relation between these is that
$$V^\pi(s) = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ Q^\pi(s, a) \right]$$
i.e. if the action fed to the Q-function is drawn from the policy $\pi$, we simply recover the value of the state. There is yet a third value function.
The (on-policy) advantage function $A^\pi(s, a)$ - how much better it is to take $a$ at state $s$ than to follow $\pi$.
$$A^\pi(s, a) = Q^\pi(s, a) - V^\pi(s)$$
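As a quick sanity check, here is a toy tabular sketch of these identities in NumPy. The MDP (2 states, 2 actions) and all the numbers are made up purely for illustration.

```python
import numpy as np

# Hypothetical tabular values illustrating V^pi(s) = E_{a~pi}[Q^pi(s, a)]
# and A^pi(s, a) = Q^pi(s, a) - V^pi(s).
Q = np.array([[1.0, 3.0],    # Q^pi(s0, a0), Q^pi(s0, a1)
              [0.5, 0.5]])   # Q^pi(s1, a0), Q^pi(s1, a1)
pi = np.array([[0.25, 0.75],  # pi(a | s0)
               [0.50, 0.50]]) # pi(a | s1)

V = (pi * Q).sum(axis=1)   # V^pi(s) = sum_a pi(a|s) Q^pi(s, a)
A = Q - V[:, None]         # A^pi(s, a) = Q^pi(s, a) - V^pi(s)
print(V)  # => [2.5 0.5]
print(A)  # advantages; each row averages to zero under pi-weighting
```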
Revisiting Policy Gradient
Recall our policy gradient had the (non-final) form
$$\nabla_\theta L(\theta) \approx -\frac{1}{N} \sum_{i=1}^N \sum_{t=1}^T \nabla_\theta \log \pi_\theta(a_t \mid s_t) \underbrace{\left( \sum_{t'=t}^T r(s_{t'}, a_{t'}) \right)}_{\text{reward to go}}$$
The term on the right is the estimate of future rewards if we take action $a_t$ in state $s_t$. Can we get a better estimate?
$$\sum_{t'=t}^T \mathbb{E}_{\pi_\theta}\left[ r(s_{t'}, a_{t'}) \mid s_t, a_t \right] = Q^{\pi_\theta}(s_t, a_t)$$
can be seen as the true expected rewards to go. This would be way better!
Should we use baselines like before? Our average Q-value would look like
$$b = \frac{1}{N} \sum_{i=1}^N Q^{\pi_\theta}(s_t, a_t)$$
This looks an awful lot like the expected value of the Q-function over actions drawn from our policy. This intuitively suggests that a good baseline is actually our value function $V^{\pi_\theta}(s_t)$.
Recall that $A^\pi(s, a) = Q^\pi(s, a) - V^\pi(s)$ is the definition of our advantage function. Thus, naturally, our policy gradient becomes
$$\nabla_\theta L(\theta) \approx -\frac{1}{N} \sum_{i=1}^N \sum_{t=1}^T \nabla_\theta \log \pi_\theta(a_t \mid s_t) A^{\pi_\theta}(s_t, a_t)$$
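As a rough illustration, here is what this advantage-weighted policy gradient loss might look like in PyTorch, assuming a discrete action space and precomputed advantage estimates. All names here (`policy_gradient_loss`, the random batch) are illustrative, not a specific library API.

```python
import torch

# A minimal sketch of the advantage-weighted policy gradient loss.
# `logits` come from a policy network over discrete actions, and `advantages`
# are precomputed estimates of A^pi(s_t, a_t), treated as constants.
def policy_gradient_loss(logits, actions, advantages):
    log_probs = torch.log_softmax(logits, dim=-1)
    chosen = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    # Detach advantages so gradients flow only through log pi_theta(a_t | s_t).
    return -(chosen * advantages.detach()).mean()

# Hypothetical usage with random data: batch of 4 states, 3 actions.
logits = torch.randn(4, 3, requires_grad=True)
actions = torch.randint(0, 3, (4,))
advantages = torch.randn(4)
loss = policy_gradient_loss(logits, actions, advantages)
loss.backward()
```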
The key is that better estimates of A lead to less noisy gradients! How can we estimate the advantage function?
Estimating Value
Since the advantage is defined in terms of both the value and the action-value functions, one might think we have to estimate both well to estimate it. However, there is actually a way to express all three quantities in terms of $V$ alone.
$$
\begin{aligned}
Q^\pi(s_t, a_t) &= \sum_{t'=t}^T \mathbb{E}_{\pi_\theta}\left[ r(s_{t'}, a_{t'}) \mid s_t, a_t \right] \\
&= r(s_t, a_t) + \sum_{t'=t+1}^T \mathbb{E}_{\pi_\theta}\left[ r(s_{t'}, a_{t'}) \mid s_t, a_t \right] \\
&= r(s_t, a_t) + \mathbb{E}_{s_{t+1} \sim p(\cdot \mid s_t, a_t)}\left[ V^\pi(s_{t+1}) \right] \\
&\approx r(s_t, a_t) + V^\pi(s_{t+1}) \quad \text{(use the sampled } s_{t+1}) \\
A^\pi(s_t, a_t) &\approx r(s_t, a_t) + V^\pi(s_{t+1}) - V^\pi(s_t)
\end{aligned}
$$
Thus we can actually just fit $V^\pi$.
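For instance, a minimal sketch of estimating advantages from a fitted value network alone might look like the following. The names (`value_net`, the batched transition tensors) are assumptions, and I include a discount factor $\gamma$, which shows up later in the post.

```python
import torch

# Sketch: given a fitted value network V_phi, estimate advantages for a batch
# of transitions using A^pi(s, a) ~= r + gamma * V(s') - V(s).
def estimate_advantages(value_net, states, rewards, next_states, dones, gamma=0.99):
    with torch.no_grad():                     # the critic is treated as fixed here
        v_s = value_net(states).squeeze(-1)
        v_next = value_net(next_states).squeeze(-1)
    # Zero out the bootstrap term at terminal states.
    return rewards + gamma * (1.0 - dones) * v_next - v_s
```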
Version 1: Monte Carlo
Our original single sample estimate is
$$V^\pi(s_t) \approx r(\tau)$$
where $\tau \sim p_\theta(\tau \mid s_1 = s_t)$. We can aggregate this into a labeled dataset of states along with their single-sample estimates
$$\left\{ \left( s_t^{(i)}, r(\tau^{(i)}) \right) \right\}$$
where $\tau^{(i)} \sim p_\theta(\tau \mid s_1 = s_t^{(i)})$ is one of the $N$ rollouts. Then we do supervised learning to fit the estimated value function.
$$L(\phi) = \frac{1}{2} \sum_{i=1}^N \left\| \hat{V}_\phi^{\pi_\theta}(s_t^{(i)}) - r(\tau^{(i)}) \right\|^2$$
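A minimal sketch of this regression in PyTorch might look like the following, assuming 4-dimensional states and tensors `states` and `returns` holding the dataset above (network size and hyperparameters are arbitrary).

```python
import torch
import torch.nn as nn

# Sketch of the Monte Carlo value regression: `states` holds the s_t^{(i)} and
# `returns` holds the corresponding single-sample targets r(tau^{(i)}).
value_net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))  # assumes 4-dim states
optimizer = torch.optim.Adam(value_net.parameters(), lr=1e-3)

def fit_value(states, returns, epochs=50):
    for _ in range(epochs):
        optimizer.zero_grad()
        pred = value_net(states).squeeze(-1)
        loss = 0.5 * ((pred - returns) ** 2).sum()   # the squared-error objective above
        loss.backward()
        optimizer.step()
```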
Version 2: Bootstrapping
Our Monte Carlo target is $r(\tau)$, and our ideal target is the true expected reward to go,
$$y_{i,t} = \sum_{t'=t}^T \mathbb{E}_{\pi_\theta}\left[ r(s_{t'}, a_{t'}) \mid s_{i,t} \right] \approx r(s_{i,t}, a_{i,t}) + \hat{V}_\phi^{\pi_\theta}(s_{i,t+1}).$$
The approximation on the right is the bootstrapped target: we plug in our own current estimate of the value function at the next state in place of the remaining expectation.
Some nasty issues can arise though. In RL, the combination of off-policy learning, function approximation, and bootstrapping is referred to as the deadly triad.
Over-iterating on the same data without fresh on-policy coverage breaks the contraction guarantees your Bellman-style updates rely on. The result is classic RL divergence rather than convergence to a sensible value function.
Idea 1: Use a KL divergence constraint on the policy.
$$\mathbb{E}_{s \sim \pi_{\theta'}}\left[ D_{\mathrm{KL}}\left( \pi_\theta(\cdot \mid s) \,\|\, \pi_{\theta'}(\cdot \mid s) \right) \right] \le \delta$$
We see this in LLM preference optimization (RLHF).
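For discrete policies, a sample-based estimate of this constraint term could look like the sketch below. `new_logits` and `old_logits` are illustrative names for the outputs of $\pi_\theta$ and $\pi_{\theta'}$ at states sampled from the old policy.

```python
import torch

# Sketch: Monte Carlo estimate of E_s[ D_KL(pi_theta(.|s) || pi_theta'(.|s)) ]
# for discrete action distributions parameterized by logits.
def mean_kl(new_logits, old_logits):
    new_logp = torch.log_softmax(new_logits, dim=-1)
    old_logp = torch.log_softmax(old_logits, dim=-1)
    # D_KL(pi_theta || pi_theta'), averaged over the sampled states.
    return (new_logp.exp() * (new_logp - old_logp)).sum(dim=-1).mean()
```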
Idea 2: Can we bound the importance weights? This doesn't directly constrain the policy, but it removes the incentive to push the weights far from one. This is the key idea behind proximal policy optimization (PPO).
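For reference, here is a sketch of PPO's standard clipped surrogate objective, which bounds the importance ratio rather than the policy itself (the function name and hyperparameters are illustrative).

```python
import torch

# Sketch of the PPO clipped surrogate objective (written as a loss to minimize).
# Clipping the ratio pi_theta(a|s) / pi_theta_old(a|s) to [1 - eps, 1 + eps]
# removes the incentive to move the ratio outside that range.
def ppo_clip_loss(new_log_probs, old_log_probs, advantages, eps=0.2):
    ratio = torch.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```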
We will talk about these ideas later, but for now our solution will be replay buffers. The idea of a replay buffer is to store the recent history of transitions seen at prior timesteps, where a transition is a state-action-reward-next-state (SARS) tuple.
We then sample transitions from the buffer in minibatches and use them to estimate our state-value function and take gradient steps.
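A minimal replay buffer sketch (capacity and interface chosen for illustration):

```python
import random
from collections import deque

# A minimal replay buffer storing (s, a, r, s', done) transitions.
class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted first

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(list(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
```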
The overview of the broken algorithm is as follows.
1. Collect experience $\{s_i, a_i\}$ from $\pi_\theta(a \mid s)$ and add it to the replay buffer $\mathcal{R}$.
2. Sample a minibatch $\{(s_i, a_i, r_i, s_i')\}$ from the buffer $\mathcal{R}$.
3. Update $\hat{V}_\phi^\pi(s)$ using targets $y_i = r_i + \gamma \hat{V}_\phi^\pi(s_i')$ for each $s_i$.
4. Evaluate $\hat{A}^\pi(s_i, a_i) = r_i + \gamma \hat{V}_\phi^\pi(s_i') - \hat{V}_\phi^\pi(s_i)$.
5. $\nabla_\theta L(\theta) \approx -\frac{1}{N} \sum_i \nabla_\theta \log \pi_\theta(a_i \mid s_i) \hat{A}^\pi(s_i, a_i)$
6. $\theta \leftarrow \theta - \eta \nabla_\theta L(\theta)$
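To make the steps concrete, here is a rough PyTorch sketch of steps 3 through 6 on a sampled minibatch. It inherits the same flaws discussed next, and `policy_net`, `value_net`, and the optimizers are illustrative names.

```python
import torch

# Sketch of steps 3-6 above. `policy_net` maps states to action logits,
# `value_net` maps states to scalar values.
def update_from_minibatch(policy_net, value_net, policy_opt, value_opt,
                          states, actions, rewards, next_states, gamma=0.99):
    # Step 3: fit V_phi to the TD targets y_i = r_i + gamma * V_phi(s_i').
    with torch.no_grad():
        targets = rewards + gamma * value_net(next_states).squeeze(-1)
    values = value_net(states).squeeze(-1)
    value_loss = 0.5 * ((values - targets) ** 2).mean()

    # Step 4: evaluate A_hat(s_i, a_i) = y_i - V_phi(s_i), treated as a constant.
    advantages = (targets - values).detach()

    value_opt.zero_grad()
    value_loss.backward()
    value_opt.step()

    # Steps 5-6: advantage-weighted policy gradient step.
    log_probs = torch.log_softmax(policy_net(states), dim=-1)
    chosen = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    policy_loss = -(chosen * advantages).mean()
    policy_opt.zero_grad()
    policy_loss.backward()
    policy_opt.step()
```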
Steps 3 through 6 are carried out on the minibatches sampled from replay memory. Why is this algorithm broken? Note that in step 3 the TD target is not correct, and therefore the action used in step 5 is not correct either!
Why? We are using trajectories from previous policies to fit the value function of our current policy. The actions stored in the buffer, and therefore the next states, are not necessarily those our current policy would have taken.
What if we fit $Q(s, a)$ instead of $V(s)$? Now we can take into account the action taken instead of assuming our action is distributed according to our current policy.
I will spare you the math and just say that our new targets are
$$y_i = r_i + \gamma \hat{Q}_\phi^{\pi_\theta}(s_i', a_i'^{\pi})$$
where $a_i'^{\pi} \sim \pi_\theta(\cdot \mid s_i')$. This matches our intended target, since it provides an unbiased single-sample estimate of the state-value function via
$$\hat{V}^{\pi_\theta}(s_i') = \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s_i')}\left[ \hat{Q}_\phi^{\pi_\theta}(s_i', a) \right]$$
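A sketch of computing these corrected targets for a discrete action space, assuming a hypothetical `q_net` that returns one Q-value per action:

```python
import torch

# Sketch of the corrected target: bootstrap with an action freshly sampled from
# the current policy at s_i', rather than the stale action stored in the buffer.
def q_targets(q_net, policy_net, rewards, next_states, gamma=0.99):
    with torch.no_grad():
        probs = torch.softmax(policy_net(next_states), dim=-1)
        a_next = torch.multinomial(probs, num_samples=1)        # a_i'^pi ~ pi_theta(. | s_i')
        q_next = q_net(next_states).gather(-1, a_next).squeeze(-1)
    return rewards + gamma * q_next
```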
Similarly, in step 5 the given action $a_i$ is not necessarily an action our policy $\pi_\theta$ would have taken. We can use the same trick, but this time at $s_i$ instead of $s_i'$: we sample $a_i^{\pi} \sim \pi_\theta(\cdot \mid s_i)$ and set
$$\nabla_\theta L(\theta) \approx -\frac{1}{N} \sum_i \nabla_\theta \log \pi_\theta(a_i^{\pi} \mid s_i) \hat{A}^\pi(s_i, a_i^{\pi})$$
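Putting the two fixes together, a sketch of the corrected policy gradient step might look like this, again assuming discrete actions and illustrative names (`policy_net`, `q_net`):

```python
import torch

# Sketch: sample a_i^pi from the current policy at the buffered states and
# weight its log-probability by an advantage built from the fitted Q-function.
def off_policy_pg_loss(policy_net, q_net, states):
    logits = policy_net(states)
    log_pi_all = torch.log_softmax(logits, dim=-1)
    a_pi = torch.multinomial(log_pi_all.exp(), num_samples=1)   # a_i^pi ~ pi_theta(. | s_i)
    log_pi = log_pi_all.gather(-1, a_pi).squeeze(-1)

    with torch.no_grad():
        q = q_net(states)                                       # Q_hat(s_i, .)
        q_a = q.gather(-1, a_pi).squeeze(-1)                    # Q_hat(s_i, a_i^pi)
        v = (log_pi_all.exp() * q).sum(dim=-1)                  # V_hat(s_i) = E_a[Q_hat(s_i, a)]
        advantages = q_a - v                                    # A_hat(s_i, a_i^pi)

    return -(log_pi * advantages).mean()
```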
Conclusion
This is a nice overview of actor-critic methods, with the efficient bootstrapping and sampling techniques used in modern RL. It is also a doorway to more standard algorithms like proximal policy optimization (PPO) and RLHF for LLMs, which I will discuss in subsequent posts.
There are some more interesting techniques for approximating the V and Q functions that I also plan to write about, popularly known as Deep Q-learning or Deep Q-Networks (DQNs). I have a previous post on the topic, but I intend to delve deeper into the math, in line with this series of posts on RL.