feat(scalarization): Add FAMO #737
PierreQuinton
left a comment
Thanks a lot for the PR!
ValerianRey
left a comment
Very good job ty!
The good news is that the pseudo-code from the paper (Algorithm 2 in Appendix C), the official implementation and the LibMTL implementation are the same, except for the default value of gamma that is 1e-5 in the official implementation and 1e-3 in Algorithm 2 and in LibMTL.
It's also essentially equivalent to Algorithm 1, up to how xi is optimized (Adam in practice, something simpler in Algorithm 1, I think). I think your implementation is also equivalent to all of the existing ones.
I haven't reviewed the tests yet. It's probably better to discuss my suggested changes first (especially incorporating the inner optimizer inside the scalarizer).
Adds `FAMO` (Fast Adaptive Multitask Optimization) from FAMO: Fast Adaptive Multitask Optimization (NeurIPS 2023). It decreases all task losses at an approximately equal rate while using only the loss values, so it never needs the per-task gradients. This implements the API we agreed on in the discussion.
Working
The values are combined as
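Consistent with the symbol definitions that follow and with the FAMO paper's objective, the combination presumably takes the following form; the exact expression, in particular the definition of $c$, is an assumption here:

$$
\mathcal{L} = \frac{1}{c} \sum_i z_i \log(\ell_i - b_i + \epsilon), \qquad z = \operatorname{softmax}(w), \qquad c = \sum_i \frac{z_i}{\ell_i - b_i + \epsilon}
$$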
where $\ell_i$ is the $i$-th value, $b_i$ is its lower bound (`min_losses`), $w$ are the learnable task-weighting logits, $z$ are the task weights, $c$ is a detached normalization constant, and $\epsilon$ is a small constant for stability. Backpropagating this loss gives FAMO's balanced update direction for the model.
Design
- The task-weighting logits are an `nn.Parameter`. The user creates a separate optimizer for them, so there is no internal optimizer. To match the paper this should be `Adam` with a weight decay equal to the paper's regularization coefficient.
- `forward` detaches the task weights, so `loss.backward()` populates only the model gradients, never the logits' gradient.
- `update(new_losses)` sets the logits' gradient from the change in losses (the softmax vector-Jacobian product) but does not step them. The user's optimizer steps them afterwards. So `update` is to the logits what `backward` is to the model parameters.
- `min_losses` is an argument defaulting to zeros (same idea as STCH's `reference`). It is the per-task lower bound subtracted before the log.
- `nan` propagates rather than being silently clamped.

Usage
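To make the two computations in the design points concrete, here is a minimal numeric sketch in plain Python, not the PR's code. It is written without PyTorch so the arithmetic is visible. The names `famo_scalarize` and `logits_grad`, the value of `EPS`, and the sign convention of the loss change (previous minus new, in log space, as in the FAMO reference implementation) are assumptions made for illustration.

```python
import math

EPS = 1e-8  # small stability constant; the value is assumed for this sketch


def softmax(w):
    """Numerically stable softmax over a list of logits."""
    m = max(w)
    exps = [math.exp(x - m) for x in w]
    total = sum(exps)
    return [e / total for e in exps]


def famo_scalarize(losses, w, min_losses):
    """Combined value (1/c) * sum_i z_i * log(l_i - b_i + eps).

    In a real autograd setting, z and c would be detached so that
    backpropagation only reaches the model parameters.
    """
    z = softmax(w)
    d = [l - b + EPS for l, b in zip(losses, min_losses)]
    c = sum(zi / di for zi, di in zip(z, d))
    return sum(zi * math.log(di) for zi, di in zip(z, d)) / c


def logits_grad(prev_losses, new_losses, w, min_losses):
    """Gradient for the logits: softmax VJP of the change in log-losses.

    Only computes the gradient; applying it is left to an optimizer.
    """
    z = softmax(w)
    delta = [
        math.log(p - b + EPS) - math.log(n - b + EPS)
        for p, n, b in zip(prev_losses, new_losses, min_losses)
    ]
    # The softmax Jacobian is diag(z) - z z^T, so J^T delta = z * (delta - <z, delta>).
    dot = sum(zi * di for zi, di in zip(z, delta))
    return [zi * (di - dot) for zi, di in zip(z, delta)]


# Task 0 improved faster (1.0 -> 0.5) than task 1 (1.0 -> 0.9).
grad = logits_grad([1.0, 1.0], [0.5, 0.9], w=[0.0, 0.0], min_losses=[0.0, 0.0])
# grad[0] is positive and grad[1] is negative, so a gradient-descent step on the
# logits shifts weight toward the task that improved more slowly.
```

The gradient entries always sum to zero, because the softmax output is constrained to sum to one. Only the relative weighting of the tasks changes.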