feat(scalarization): Add DWA #731
Open
ppraneth wants to merge 9 commits into
ValerianRey (Member) requested changes on Jun 11, 2026:
Very good! Just a few nitpicks and some minor problems with the documentation.
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
ppraneth (Contributor, Author):
@ValerianRey I have updated the docstrings.
ValerianRey (Member) approved these changes on Jun 11, 2026:
Tyvm! LGTM. @PierreQuinton are you OK with merging this?
Adds DWA, the Dynamic Weight Average scalarizer from End-to-End Multi-Task Learning with Attention (CVPR 2019). It weights each task by how fast its loss has been going down compared to the others: tasks whose loss decreases more slowly get larger weights.

Working
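At epoch $t$, the values are combined as:

$$\sum_{k=1}^{K} \lambda_k(t)\,\ell_k, \qquad \lambda_k(t) = \frac{K \exp\big(w_k(t-1)/T\big)}{\sum_{i=1}^{K} \exp\big(w_i(t-1)/T\big)}, \qquad w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)}$$

where $\ell_k$ is the $k$-th value (loss) passed in, $L_k(\tau)$ is its average over epoch $\tau$, $K$ is the number of values, and $T$ is the temperature (2.0 by default). The weights only need past loss values, no gradients (the paper notes this is why it's simpler than GradNorm), so it fits the `Scalarizer` interface.

Usage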
The weights at epoch $t$ depend on the average losses of epochs $t-1$ and $t-2$. The scalarizer can't tell on its own when an epoch ends, so the user calls it every batch and calls its `step()` method once at the end of each epoch: `forward` records each batch's losses, and `step()` finalizes the just-finished epoch's average loss and rolls the history forward (dropping the average from two epochs ago). This matches how the paper and LibMTL use it (per epoch). During the first two epochs (before there are two averages), the weights are uniform.

Design notes
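- `DWA(temperature=2.0)`; `temperature` must be `> 0`.
- The loss history is internal state; `reset()` clears it.
- No `shape` argument is needed (unlike `UW`/`IMTLL`): the buffer is created lazily from the inputs.
- The ratio of a loss's average over the previous epoch to its average over the epoch before is expected to lie in $(0, +\infty)$. So the losses are expected to keep a consistent, nonzero sign across epochs; they need not be positive, and positivity is not enforced.

Tests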
`tests/unit/scalarization/test_dwa.py` covers:
- the uniform (bootstrap) behavior for the first two epochs;
- the exact weight formula;
- that the per-epoch average is used (not just the last batch);
- that `step()` drops the oldest epoch;
- that the weights sum to the number of values;
- scalar output and gradient flow over all input shapes (including the computed-weights path after two epochs);
- support for consistently-signed negative losses (the ratio of same-sign losses stays positive);
- `reset()`;
- `step()` being a no-op with no data;
- that there are no learnable parameters;
- shape-change errors (within and between epochs);
- temperature validation;
- the representations.
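The weighting and the per-epoch bookkeeping described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the PR's implementation: the class name `DWASketch` and its `weights()` helper are hypothetical, losses are plain floats rather than tensors (so there is no gradient flow to show), and each call receives one list of per-task losses for a batch. Only the calling pattern (call every batch, `step()` once per epoch, `reset()` to clear), the uniform weights during the first two epochs, the temperature check, and the formula itself come from the description.

```python
import math
from typing import List, Optional, Sequence


class DWASketch:
    """Hypothetical pure-Python sketch of Dynamic Weight Average (not the PR's DWA class).

    Call it once per batch with the per-task losses, and call step() once at the
    end of each epoch.
    """

    def __init__(self, temperature: float = 2.0) -> None:
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature
        self.reset()

    def reset(self) -> None:
        # Running sums and batch count for the current epoch, plus the averages of
        # the last two finished epochs (oldest first).
        self._sums: Optional[List[float]] = None
        self._count = 0
        self._history: List[List[float]] = []

    def weights(self, k: int) -> List[float]:
        if len(self._history) < 2:
            return [1.0] * k  # uniform weights until two epoch averages exist
        older, newer = self._history
        # w_k(t-1) = L_k(t-1) / L_k(t-2), then K * softmax(w / T).
        ratios = [n / o for n, o in zip(newer, older)]
        exps = [math.exp(r / self.temperature) for r in ratios]
        total = sum(exps)
        return [k * e / total for e in exps]  # the weights sum to k

    def __call__(self, losses: Sequence[float]) -> float:
        k = len(losses)
        if self._sums is None:
            self._sums = [0.0] * k  # buffer created lazily from the first input
        elif len(self._sums) != k:
            raise ValueError("the number of losses changed")
        # Record this batch so step() can compute the epoch average.
        self._sums = [s + x for s, x in zip(self._sums, losses)]
        self._count += 1
        # The weights depend only on past epoch averages, never on this batch.
        w = self.weights(k)
        return sum(wi * xi for wi, xi in zip(w, losses))

    def step(self) -> None:
        if self._count == 0:
            return  # no-op when no batch was recorded this epoch
        avg = [s / self._count for s in self._sums]
        # Keep only the two most recent epoch averages (drop the oldest).
        self._history = (self._history + [avg])[-2:]
        self._sums = [0.0] * len(avg)
        self._count = 0


if __name__ == "__main__":
    dwa = DWASketch(temperature=2.0)
    # Epoch 1: two batches, uniform weights; the epoch average is [3.0, 4.0].
    dwa([2.0, 4.0])
    dwa([4.0, 4.0])
    dwa.step()
    # Epoch 2: one batch, still uniform; the epoch average is [1.5, 4.0].
    dwa([1.5, 4.0])
    dwa.step()
    # Epoch 3: the ratios are 1.5 / 3.0 = 0.5 and 4.0 / 4.0 = 1.0.
    print(dwa.weights(2))
```

In the demo, the first task (ratio 0.5) gets a smaller weight than the second (ratio 1.0): the task whose loss dropped faster is down-weighted, and the two weights still sum to 2.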