feat(scalarization): Add DWA #731

Open
ppraneth wants to merge 9 commits into SimplexLab:main from ppraneth:scalarization-6

@ppraneth commented Jun 11, 2026

Contributor

Adds DWA, the Dynamic Weight Average scalarizer from End-to-End Multi-Task Learning with Attention (CVPR 2019). It weights each task by how fast its loss has been going down compared to the others.

How it works

At epoch $t$, the values are combined as:

$$\sum_k \lambda_k(t) \, L_k(t), \qquad \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)}$$

  • $w_k$ is the ratio of a task's average losses over the two previous epochs (a task whose loss dropped less gets more weight).
  • $T$ is the temperature (larger $T$ → more uniform weights). The paper uses 2.0.
  • $K$ is the number of values, so the weights sum to $K$.
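As an illustration of the formula above, here is a minimal pure-Python sketch of the weight computation. The function name `dwa_weights` and its arguments are hypothetical, chosen for this example; they are not part of torchjd, and plain floats stand in for tensors.

```python
import math


def dwa_weights(prev_avg, prev_prev_avg, temperature=2.0):
    """Hypothetical helper: DWA weights from the two previous epochs' average losses."""
    # w_k(t-1) = L_k(t-1) / L_k(t-2): relative descending rate of each task
    ratios = [a / b for a, b in zip(prev_avg, prev_prev_avg)]
    # Softmax over ratios / T, scaled by K so that the weights sum to K
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    k = len(ratios)
    return [k * e / total for e in exps]


# Task 0 halved its loss, task 1 barely improved.
weights = dwa_weights(prev_avg=[0.5, 0.9], prev_prev_avg=[1.0, 1.0])
# Same losses with a much larger temperature: weights move toward uniform.
flat = dwa_weights(prev_avg=[0.5, 0.9], prev_prev_avg=[1.0, 1.0], temperature=100.0)
```

In this sketch the slower task (task 1) receives the larger weight, and raising the temperature pulls both weights toward 1.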

The weights only need past loss values, no gradients (the paper notes this is why it's simpler than GradNorm), so it fits the Scalarizer interface.

Usage

The weights at epoch $t$ depend on the average losses of epochs $t-1$ and $t-2$. The scalarizer can't tell on its own when an epoch ends, so the user calls it every batch and calls a step() method once at the end of each epoch:

scalarizer = DWA()
for epoch in range(n_epochs):
    for batch in loader:
        losses = ...                 # one loss per task
        loss = scalarizer(losses)    # weighted sum, also records the batch's losses
        loss.backward()
        optimizer.step()
    scalarizer.step()                # roll the epoch history, once per epoch

forward records each batch's losses, and step() finalizes the just-finished epoch's average loss and rolls the history forward (dropping the one from two epochs ago). This matches how the paper and LibMTL use it (per-epoch). During the first two epochs (before there are two averages) the weights are uniform.
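The bookkeeping described above can be sketched in plain Python. This is only an assumption-laden illustration: `DWASketch` is a hypothetical stand-in that works on lists of floats, not the torchjd class, and it omits tensors, buffers, detaching, and shape-change checks.

```python
import math


class DWASketch:
    """Hypothetical pure-Python stand-in, not the torchjd implementation."""

    def __init__(self, temperature=2.0):
        if temperature <= 0:
            raise ValueError("temperature must be > 0")
        self.temperature = temperature
        self.reset()

    def reset(self):
        self._sums = None    # running per-task loss sums for the current epoch
        self._count = 0      # number of batches seen in the current epoch
        self._history = []   # at most two finished-epoch averages, oldest first

    def weights(self, k):
        if len(self._history) < 2:
            return [1.0] * k  # uniform until two epoch averages exist
        older, newer = self._history
        exps = [math.exp((n / o) / self.temperature) for n, o in zip(newer, older)]
        total = sum(exps)
        return [k * e / total for e in exps]

    def __call__(self, losses):
        # Record this batch's losses, then return the weighted sum.
        if self._sums is None:
            self._sums = [0.0] * len(losses)
        self._sums = [s + l for s, l in zip(self._sums, losses)]
        self._count += 1
        return sum(w * l for w, l in zip(self.weights(len(losses)), losses))

    def step(self):
        if self._count == 0:
            return  # nothing recorded this epoch: no-op
        avg = [s / self._count for s in self._sums]
        self._history = (self._history + [avg])[-2:]  # drop the oldest epoch
        self._sums, self._count = None, 0


s = DWASketch()
first = s([1.0, 1.0])    # epoch 0: uniform weights
s.step()
second = s([0.5, 0.9])   # epoch 1: still uniform, only one average so far
s.step()
third = s([1.0, 1.0])    # epoch 2: weights computed from epochs 0 and 1
```

Here the first two calls use uniform weights, and the third uses weights derived from the two stored epoch averages.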

Design notes

  • DWA(temperature=2.0): the temperature must be > 0.
  • Its state is a non-trainable buffer; reset() clears it.
  • No shape argument is needed (unlike UW/IMTLL): the buffer is created lazily from the inputs.
  • The weights are detached, so gradients flow only through the current batch's losses.
  • It weights each value by the ratio of its losses over consecutive epochs, which the paper defines as a descending rate in the range (0, +∞). The losses are therefore expected to keep a consistent, nonzero sign across epochs. They need not be positive, and positivity is not enforced.
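The sign requirement in the last note can be checked with a small arithmetic sketch. The helper `descending_rate` is hypothetical and only restates the ratio $w_k(t-1) = L_k(t-1) / L_k(t-2)$ for a single task.

```python
def descending_rate(prev_avg, prev_prev_avg):
    """Hypothetical helper: the ratio L(t-1) / L(t-2) for one task."""
    return prev_avg / prev_prev_avg


# Both epoch averages negative: the ratio stays in (0, +inf).
same_sign_negative = descending_rate(-3.0, -2.0)
# Sign flipped between epochs: the ratio falls outside (0, +inf).
mixed_sign = descending_rate(-1.0, 2.0)
```

Two negative averages give a positive ratio, while a sign flip between epochs gives a negative one, which is outside the range the paper assumes.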

Tests

tests/unit/scalarization/test_dwa.py covers: the uniform/bootstrap behavior for the first two epochs, the exact weight formula, that the per-epoch average is used (not just the last batch), that step() drops the oldest epoch, that the weights sum to the number of values, scalar output and gradient flow over all input shapes (including the computed-weights path after two epochs), support for consistently-signed negative losses (the ratio of same-sign losses stays positive), reset(), step() being a no-op with no data, that there are no learnable parameters, shape-change errors (within and between epochs), temperature validation, and the representations.

ppraneth added 2 commits June 11, 2026 09:28
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth requested review from a team, PierreQuinton and ValerianRey as code owners June 11, 2026 04:18
@ppraneth added the "cc: feat" and "package: scalarization" labels Jun 11, 2026
@github-actions bot changed the title from "Scalarization 6" to "feat(scalarization): Scalarization 6" Jun 11, 2026
@ppraneth changed the title from "feat(scalarization): Scalarization 6" to "feat(scalarization): Add DWA" Jun 11, 2026

@ValerianRey left a comment

Member

Very good! Just a few nitpicks and some minor problems in the documentation.

Comment thread tests/unit/scalarization/test_dwa.py Outdated
Comment thread tests/unit/scalarization/test_dwa.py Outdated
Comment thread src/torchjd/scalarization/_dwa.py Outdated
Comment thread src/torchjd/scalarization/_dwa.py Outdated
Comment thread src/torchjd/scalarization/_dwa.py
ppraneth and others added 4 commits June 11, 2026 15:46
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth requested a review from ValerianRey June 11, 2026 10:31
@ppraneth

Contributor Author

@ValerianRey I have updated the docstrings.

@ValerianRey

Member

Tyvm! LGTM. @PierreQuinton are you ok with merging this?
