Skip to content

feat(scalarization): Add FAMO#737

Merged
ValerianRey merged 6 commits into
SimplexLab:mainfrom
ppraneth:scalarization-7
Jun 15, 2026
Merged

feat(scalarization): Add FAMO#737
ValerianRey merged 6 commits into
SimplexLab:mainfrom
ppraneth:scalarization-7

Conversation

@ppraneth

@ppraneth ppraneth commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

Adds FAMO (Fast Adaptive Multitask Optimization) from FAMO: Fast Adaptive Multitask Optimization (NeurIPS 2023). It decreases all task losses at an approximately equal rate while using only the loss values, so it never needs the per-task gradients.

This implements the API we agreed on in the discussion.

Working

The values are combined as

$$c \sum_i z_i \log(\ell_i - b_i + \epsilon), \qquad z = \mathrm{softmax}(w), \qquad c = \left( \sum_i \frac{z_i}{\ell_i - b_i + \epsilon} \right)^{-1}$$

where $\ell_i$ is the $i$-th value, $b_i$ is its lower bound (min_losses), $w$ are the learnable task-weighting logits, $z$ are the task weights, $c$ is a detached normalization constant, and $\epsilon$ is a small constant for stability. Backpropagating this loss gives FAMO's balanced update direction for the model.

Design

  • The logits $w$ are a public nn.Parameter. The user creates a separate optimizer for them, so there is no internal optimizer. To match the paper this should be Adam with a weight decay equal to the paper's regularization coefficient.
  • forward detaches the task weights, so loss.backward() populates only the model gradients, never the logits' gradient.
  • After the model step, the user recomputes the losses on the same batch and calls update(new_losses). This sets the logits' gradient from the change in losses (the softmax vector-Jacobian product) but does not step them. The user's optimizer steps them afterwards. So update is to the logits what backward is to the model parameters.
  • min_losses is an argument defaulting to zeros (same idea as STCH's reference). It is the per-task lower bound subtracted before the log.
  • FAMO takes the log of $\ell_i - b_i$, so each value must stay strictly above its bound (with the default zeros, the values must be positive). This is documented but not enforced, and a test locks that nan propagates rather than being silently clamped.

Usage

model = Linear(3, 2)
scalarizer = FAMO(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
weight_optimizer = torch.optim.Adam(scalarizer.parameters(), lr=0.025, weight_decay=0.001)

features = torch.randn(8, 3)
losses = model(features).pow(2).mean(dim=0)
loss = scalarizer(losses)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Recompute the losses on the same batch, after the model update.
new_losses = model(features).pow(2).mean(dim=0)
scalarizer.update(new_losses)  # Sets the gradient of the task-weighting logits.
weight_optimizer.step()

@ppraneth ppraneth requested review from a team, PierreQuinton and ValerianRey as code owners June 13, 2026 11:09
@ppraneth ppraneth added cc: feat Conventional commit type for new features. package: scalarization labels Jun 13, 2026
@github-actions github-actions Bot changed the title add FAMO feat(scalarization): Add FAMO Jun 13, 2026

@PierreQuinton PierreQuinton left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR!

Comment thread src/torchjd/scalarization/_famo.py
Comment thread src/torchjd/scalarization/_famo.py Outdated
Comment thread src/torchjd/scalarization/_famo.py
Comment thread src/torchjd/scalarization/_famo.py Outdated

@ValerianRey ValerianRey left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good job ty!

The good news is that the pseudo-code from the paper (Algorithm 2 in Appendix C), the official implementation and the LibMTL implementation are the same, except for the default value of gamma that is 1e-5 in the official implementation and 1e-3 in Algorithm 2 and in LibMTL.

It's also quite equivalent with Algorithm 1, up to the practical optimization of xi (adam in practice, more simple in Algorithm 1 i think). I think your implementation is also equivalent to all of the existing ones.

I didn't review the tests yet. Maybe it's better to discuss about my suggested changes first (especially incorporating the inner optimizer inside the scalarizer).

Comment thread src/torchjd/scalarization/_famo.py Outdated
Comment thread src/torchjd/scalarization/_famo.py Outdated
Comment thread src/torchjd/scalarization/_famo.py Outdated
Comment thread src/torchjd/scalarization/_famo.py
@ValerianRey ValerianRey self-requested a review June 15, 2026 14:14
@ValerianRey ValerianRey enabled auto-merge (squash) June 15, 2026 14:15
@ValerianRey ValerianRey merged commit ca50c7f into SimplexLab:main Jun 15, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: scalarization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants