[2025-04-17] Why categorical regression stabilizes RL

It's well-known that classification-based objectives mog regression-based (i.e. MSE) objectives for training value estimators in RL, but why?

Although the classical distributional RL propaganda suggests that the improvement stems from modeling the entire distribution of returns rather than just its expectation, Farebrother et al. 2024 find otherwise. They compute a scalar Q-function target using only the model's expected values, construct a target distribution (a histogram of a Gaussian centered at that scalar target), and minimize the KL w.r.t. that target. This outperforms distributional backups, despite the fact that this objective, like MSE, essentially only propagates the expectation of the value back to previous states.
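
As a rough sketch of that target construction (Farebrother et al. call it HL-Gauss): the bin range, bin count, `sigma`, and every function name below are assumptions picked for illustration, not values from the paper. Cross-entropy stands in for the KL here, since the two differ only by the target's entropy, which is constant w.r.t. the logits.

```python
import math

def hl_gauss_target(scalar_target, bin_edges, sigma):
    """Histogram of a Gaussian N(scalar_target, sigma^2) over the given bins."""
    def cdf(x):
        return 0.5 * (1.0 + math.erf((x - scalar_target) / (sigma * math.sqrt(2.0))))
    cdfs = [cdf(edge) for edge in bin_edges]
    total = cdfs[-1] - cdfs[0]  # renormalize using only the mass inside the support
    return [(hi - lo) / total for lo, hi in zip(cdfs[:-1], cdfs[1:])]

def cross_entropy(target_probs, logits):
    """Cross-entropy between a fixed target histogram and softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return -sum(p * (l - log_z) for p, l in zip(target_probs, logits))

# Hypothetical support: 10 equal-width bins on [-5, 5].
num_bins, v_min, v_max = 10, -5.0, 5.0
bin_edges = [v_min + (v_max - v_min) * i / num_bins for i in range(num_bins + 1)]
bin_centers = [(lo + hi) / 2.0 for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]

scalar_target = 1.3  # e.g. R_t + gamma * E[Q(S_{t+1}, a)], computed elsewhere
target_probs = hl_gauss_target(scalar_target, bin_edges, sigma=0.75)

logits = [0.0] * num_bins  # an untrained categorical head predicting uniform
loss = cross_entropy(target_probs, logits)

# The histogram's mean stays close to the scalar it was built from.
target_mean = sum(p * c for p, c in zip(target_probs, bin_centers))
```

The only thing the target distribution encodes is the scalar target plus a fixed amount of smoothing, which is why this objective still only backs up expectations.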

My headcanon is that it's more of a representational thing... more specifically, consider the standard one-step TD update: \[Q(S_t, A_t) \leftarrow R_t + \gamma \mathbb E_{a \sim \pi(\cdot \mid S_{t + 1})} [ Q(S_{t + 1}, a) ] \] which can be implemented (in the regression-based objective case) via gradient descent on the MSE loss \[\mathcal{L} = \left( Q(S_t, A_t) - R_t - \gamma \;\text{sg} \{\!\{ \mathbb E_{a \sim \pi(\cdot \mid S_{t + 1})} [ Q(S_{t + 1}, a) ] \}\!\} \right)^2\] (where \(\text{sg} \{\!\{\cdot\}\!\}\) is a stop-grad / detach operation).
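
A minimal sketch of that target and loss, with the stop-grad made explicit by treating the target as a plain constant. The function names and the numbers are made up for illustration.

```python
def td_target(reward, gamma, next_q_values, next_action_probs):
    """R_t + gamma * E_{a ~ pi(.|S_{t+1})}[Q(S_{t+1}, a)]."""
    expected_next_q = sum(p * q for p, q in zip(next_action_probs, next_q_values))
    return reward + gamma * expected_next_q

def semi_gradient_mse(q_sa, target):
    """MSE loss and its derivative w.r.t. Q(S_t, A_t) only.

    The target is treated as a constant, which is what sg{{.}} does:
    no gradient flows into Q(S_{t+1}, a).
    """
    error = q_sa - target
    return error ** 2, 2.0 * error

# Hypothetical numbers: two actions in S_{t+1}, uniform policy.
target = td_target(reward=1.0, gamma=0.9,
                   next_q_values=[2.0, 4.0], next_action_probs=[0.5, 0.5])
loss, dloss_dq = semi_gradient_mse(q_sa=3.0, target=target)
```

Here the prediction is an underestimate, so `dloss_dq` is negative and a descent step raises \(Q(S_t, A_t)\).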

RL theorists will know that raw TD learning kinda sucks in terms of being stable or actually converging, and most RL implementations resort to EMA target networks to compute Q-targets for stability.

Let's take a deeper look at what the gradients are up to for MSE regression...

In neural networkland the scalar output is usually computed by a final linear layer mapping the network's hidden dimension to a single output. In other words, the output is computed by taking the penultimate activation vector and evaluating its inner product with a learnable vector parameter.

For any particular input/output pair, the gradient-descent update applied to this learnable vector parameter is just some scalar multiple of the penultimate activation vector: a positive scalar if the current prediction is an underestimate, a negative scalar if it is an overestimate.
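
To spell that out under assumed notation (none of these symbols appear above: \(\phi_t\) for the penultimate activation at \((S_t, A_t)\), \(w\) for the final-layer parameter vector, \(y\) for the stop-gradded target, \(\eta\) for the learning rate), with \(Q(S_t, A_t) = w^\top \phi_t\) the last-layer gradient and the resulting descent step are \[\nabla_w \left( w^\top \phi_t - y \right)^2 = 2 \left( w^\top \phi_t - y \right) \phi_t, \qquad w \leftarrow w + 2 \eta \left( y - w^\top \phi_t \right) \phi_t\] so the step is a positive multiple of \(\phi_t\) exactly when \(w^\top \phi_t < y\), i.e. when the prediction is an underestimate.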

In many RL environments, a plausibly common scenario is one where the penultimate activation (feature) vectors of two adjacent states (let's denote them \(S_t\) and \(S_{t+1}\)) are relatively aligned, or even identical. In this case, if the TD error between these two states is nonzero, then the learnable parameter vector will change by some multiple of the former state's penultimate activation vector.

For example, if the reward \(R_t\) after state-action \(S_t, A_t\) is positive, then the TD update will encourage \(Q(S_t, A_t)\) to increase. The learnable parameter vector will move in the direction of the penultimate activation vector of \(S_t\), which ends up changing the Q-value estimates of BOTH states in the same direction. If the penultimate activation of \(S_{t + 1}\) has a larger component along that direction than that of \(S_t\) (by enough to outweigh the discount \(\gamma\)), the update ends up worsening the overall TD error.
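
A toy numeric check of this failure mode, assuming a purely linear Q-head over fixed 2-d features. The feature values, learning rate, and \(\gamma\) are all made up, and the factor of 2 from the MSE gradient is folded into the learning rate.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

gamma, lr, reward = 0.99, 0.1, 1.0

# Hypothetical penultimate activations: S_{t+1} points the same way as S_t
# but is slightly longer, so it has the larger component along phi_t.
phi_t = [1.0, 0.0]
phi_next = [1.2, 0.0]

def td_error(w):
    """delta = R_t + gamma * Q(S_{t+1}) - Q(S_t) for a linear head Q = w . phi."""
    return reward + gamma * dot(w, phi_next) - dot(w, phi_t)

w = [0.0, 0.0]
delta_before = td_error(w)

# Semi-gradient (stop-grad) step: move w along phi_t, scaled by the TD error.
w_semi = [wi + lr * delta_before * pi for wi, pi in zip(w, phi_t)]
delta_after = td_error(w_semi)

q_t_change = dot(w_semi, phi_t) - dot(w, phi_t)           # Q(S_t) goes up...
q_next_change = dot(w_semi, phi_next) - dot(w, phi_next)  # ...Q(S_{t+1}) goes up more
```

With these numbers the TD error grows from 1.0 to about 1.0188 after the step: the target ran away faster than the prediction chased it.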

In this specific case, a full-gradient style objective \[\mathcal{L} = \left( Q(S_t, A_t) - R_t - \gamma \; \mathbb E_{a \sim \pi(\cdot \mid S_{t + 1})} [ Q(S_{t + 1}, a) ] \right)^2\] without the stop-grad would correctly identify that the parameter vector should actually move in the opposite direction (locally decreasing both Q-values and improving the TD error), although, of course, the stability and correctness of this objective come at the cost of incredibly slow learning (again, see Schnell et al. 2025).
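
A sketch of the full-gradient step on a toy linear Q-head, under made-up 2-d features where \(S_{t+1}\)'s activation is a slightly longer copy of \(S_t\)'s. All names and numbers are assumptions, and the factor of 2 is again folded into the learning rate.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

gamma, lr, reward = 0.99, 0.1, 1.0
phi_t = [1.0, 0.0]     # hypothetical activation for (S_t, A_t)
phi_next = [1.2, 0.0]  # hypothetical activation for S_{t+1}

def td_error(w):
    return reward + gamma * dot(w, phi_next) - dot(w, phi_t)

w = [0.0, 0.0]
delta_before = td_error(w)

# Without the stop-grad, d(delta^2)/dw is proportional to
# delta * (gamma * phi_next - phi_t), so the descent step is along
# (phi_t - gamma * phi_next) rather than along phi_t alone.
direction = [a - gamma * b for a, b in zip(phi_t, phi_next)]
w_full = [wi + lr * delta_before * d for wi, d in zip(w, direction)]
delta_after = td_error(w_full)

q_t_after = dot(w_full, phi_t)        # both Q-values end up below zero...
q_next_after = dot(w_full, phi_next)  # ...and Q(S_{t+1}) drops further
```

The step points the opposite way from the semi-gradient one and the TD error does shrink, but only from 1.0 to about 0.9965, because the effective direction `phi_t - gamma * phi_next` is tiny when the two activations are this similar.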

Using an EMA target network mildly alleviates the issues with the TD update by reducing the speed at which updates to the online estimate of \(Q(S_t, A_t)\) propagate to the target against which \(Q(S_t, A_t)\) is itself trained. Unfortunately, this also introduces a fundamental trade-off between the stability of learning (which worsens as the target network updates faster per optimization step) versus the maximum theoretical computational efficiency of learning (which improves as the target network updates faster per optimization step).
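
For concreteness, a minimal sketch of an EMA target update on a flat list of parameters. The name `tau` and the exact update form are the usual convention rather than anything specified above, and the numbers are illustrative.

```python
def ema_update(target_params, online_params, tau):
    """target <- (1 - tau) * target + tau * online, elementwise."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]

# Pretend the online parameter just jumped from 0 to 1 and then stays put.
online_params = [1.0]
slow_target, fast_target = [0.0], [0.0]

for _ in range(10):
    slow_target = ema_update(slow_target, online_params, tau=0.01)
    fast_target = ema_update(fast_target, online_params, tau=0.1)
```

After 10 steps the `tau=0.1` target has covered about 65% of the gap while the `tau=0.01` target has covered under 10%, which is the speed-versus-stability knob in miniature.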

Do note that this can, in theory, be disentangled from the sample efficiency of learning (assuming you're using a stable learning objective, e.g. AWR or the Muesli policy loss as opposed to something unstable like policy gradient or PPO... I should probably make a post on this someday...), since you can just introduce high update-to-data ratios to combat the ill-conditioning of full-gradient learning, for example.

However, categorical regression introduces, at least in my opinion, a much more complete solution to this problem. As a refresher, this approach involves using a linear layer to map the penultimate activation vector to a vector of logits, which then evaluates to a discrete distribution over possible Q-values via softmax. The penultimate activation being more aligned with any particular row of the weight matrix of this linear layer implies a more confident belief that the Q-value for the input state is close to the constant scalar atom corresponding to that row.
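
A small sketch of such a categorical head, with three atoms and hand-picked weights purely for illustration (none of these values come from a real model):

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def categorical_q(weight_rows, atoms, phi):
    """Map a penultimate activation to (distribution over atoms, expected Q)."""
    logits = [dot(row, phi) for row in weight_rows]
    probs = softmax(logits)
    return probs, sum(p * a for p, a in zip(probs, atoms))

atoms = [-1.0, 0.0, 1.0]            # fixed scalar Q-values, one per row
weight_rows = [[-1.0, 0.0],         # row for atom -1
               [0.0, 1.0],          # row for atom  0
               [1.0, 0.0]]          # row for atom +1

phi = [2.0, 0.0]  # most aligned with the last row
probs, q_value = categorical_q(weight_rows, atoms, phi)
```

Because `phi` lines up with the row for the `+1` atom, most of the probability mass lands there and the expected Q-value comes out around 0.85.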

With scalar regression, taking any particular update to the last parameter vector and scaling it up will almost always asymptotically push the predicted Q-value towards positive or negative infinity. However, under categorical regression, scaling any particular update to the last weight matrix simply increases the confidence of some particular Q-value, depending on the input, which naturally improves learning stability.
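
A toy comparison of the two heads when one fixed update is scaled up by increasing factors. Both heads start from zero weights, and the update values are arbitrary; everything here is an illustrative assumption.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

phi = [1.0, 0.5]  # hypothetical penultimate activation
scales = [1, 10, 100, 1000]

# Scalar head: Q = (c * delta_w) . phi grows linearly in the scale c.
delta_w = [0.3, -0.2]
scalar_q = [c * dot(delta_w, phi) for c in scales]

# Categorical head: the same scaling only sharpens the softmax.
atoms = [-1.0, 0.0, 1.0]
delta_W = [[-0.1, 0.0], [0.2, 0.1], [0.1, -0.3]]
categorical_q = []
for c in scales:
    logits = [c * dot(row, phi) for row in delta_W]
    probs = softmax(logits)
    categorical_q.append(sum(p * a for p, a in zip(probs, atoms)))
```

The scalar prediction blows up with the scale (roughly 0.2, 2, 20, 200), while the categorical prediction never leaves \([-1, 1]\) and just settles on the atom whose row the update favors for this input (here the middle atom, 0).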

In my experience, a target network still adds some stability on top, but using categorical regression over MSE is an overall Pareto improvement on the trade-off between speed and stability.