You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
GradVacWeighting._phi_t, CRMOGMWeighting._lambda, MoDoWeighting._lambda, and SDMGradWeighting._w were plain Python attributes initialised to None.
Changed each to self.register_buffer(...) so that nn.Module.to(device="cuda") (or .to(dtype=...)) automatically moves the warm-started state tensor when the user migrates their model to GPU.
_NashMTLWeighting is unchanged — its state is NumPy/cvxpy arrays, not PyTorch tensors.
_state_key tuples are left as plain attributes (not tensors).
Existing forward() and reset() assignments (self._w = w, self._w = None) continue to work: PyTorch's __setattr__ routes assignments to registered buffer names through _buffers automatically.
Test plan
uv run pytest tests/unit -q — 3167 passed, 0 failed
uv run ty check on all five files — 0 new errors (7 pre-existing NashMTL/cvxpy errors unchanged)
github-actionsBot
changed the title
fix(aggregation): register state tensors as buffers so .to(device) moves them
fix: Register state tensors as buffers so .to(device) moves them
Jun 14, 2026
ValerianRey
changed the title
fix: Register state tensors as buffers so .to(device) moves them
fix: Register state tensors as buffers
Jun 15, 2026
ValerianRey
added
cc: refactor
Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements
and removed
cc: fix
Conventional commit type for bug fixes of the actual library (changes to src).
labels
Jun 15, 2026
github-actionsBot
changed the title
fix: Register state tensors as buffers
refactor(aggregation): Register state tensors as buffers
Jun 15, 2026
Arguably I think we could go 1 step further and also register as buffers the tensors that can only be changed through the init or through a setter.
This way, calling .to(device) will actually always work, as far as I know.
The only question is: do we want to have such tensors in the state dict of the aggregator / scalarizer? I don't think so, because these are not necessarily stateful aggregators. So in my opinion we should use persistent=False when registering those kinds of buffers. @KhusPatel4450 wanna take that one too?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cc: refactorConventional commit type for any refactoring, not user-facing, and not typing or perf improvementspackage: aggregation
3 participants
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
GradVacWeighting._phi_t,CRMOGMWeighting._lambda,MoDoWeighting._lambda, andSDMGradWeighting._wwere plain Python attributes initialised toNone.self.register_buffer(...)so thatnn.Module.to(device="cuda")(or.to(dtype=...)) automatically moves the warm-started state tensor when the user migrates their model to GPU._NashMTLWeightingis unchanged — its state is NumPy/cvxpy arrays, not PyTorch tensors._state_keytuples are left as plain attributes (not tensors).forward()andreset()assignments (self._w = w,self._w = None) continue to work: PyTorch's__setattr__routes assignments to registered buffer names through_buffersautomatically.Test plan
uv run pytest tests/unit -q— 3167 passed, 0 faileduv run ty checkon all five files — 0 new errors (7 pre-existing NashMTL/cvxpy errors unchanged)Claude Code was used to assist with this task