Skip to content

refactor(aggregation): Register state tensors as buffers#739

Merged
ValerianRey merged 1 commit into
SimplexLab:mainfrom
KhusPatel4450:feat/register-state-buffers
Jun 15, 2026
Merged

refactor(aggregation): Register state tensors as buffers#739
ValerianRey merged 1 commit into
SimplexLab:mainfrom
KhusPatel4450:feat/register-state-buffers

Conversation

@KhusPatel4450

Copy link
Copy Markdown
Contributor

Summary

  • GradVacWeighting._phi_t, CRMOGMWeighting._lambda, MoDoWeighting._lambda, and SDMGradWeighting._w were plain Python attributes initialised to None.
  • Changed each to self.register_buffer(...) so that nn.Module.to(device="cuda") (or .to(dtype=...)) automatically moves the warm-started state tensor when the user migrates their model to GPU.
  • _NashMTLWeighting is unchanged — its state is NumPy/cvxpy arrays, not PyTorch tensors.
  • _state_key tuples are left as plain attributes (not tensors).
  • Existing forward() and reset() assignments (self._w = w, self._w = None) continue to work: PyTorch's __setattr__ routes assignments to registered buffer names through _buffers automatically.

Test plan

  • uv run pytest tests/unit -q — 3167 passed, 0 failed
  • uv run ty check on all five files — 0 new errors (7 pre-existing NashMTL/cvxpy errors unchanged)

Claude Code was used to assist with this task

@KhusPatel4450 KhusPatel4450 added the cc: fix Conventional commit type for bug fixes of the actual library (changes to src). label Jun 14, 2026
@KhusPatel4450 KhusPatel4450 self-assigned this Jun 14, 2026
@github-actions github-actions Bot changed the title fix(aggregation): register state tensors as buffers so .to(device) moves them fix: Register state tensors as buffers so .to(device) moves them Jun 14, 2026

@ValerianRey ValerianRey left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @PierreQuinton This is related to what you asked in #737.

@PierreQuinton PierreQuinton left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool, I guess the only problem is that they are typed as Any? But this doesn't matter too much. LGTM

@ValerianRey

Copy link
Copy Markdown
Member

Very cool, I guess the only problem is that they are typed as Any? But this doesn't matter too much. LGTM

No, it's typed as None | Tensor, same as before. I'm not sure how ty infers that though, but it works.

@ValerianRey ValerianRey changed the title fix: Register state tensors as buffers so .to(device) moves them fix: Register state tensors as buffers Jun 15, 2026
@ValerianRey ValerianRey added cc: refactor Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements and removed cc: fix Conventional commit type for bug fixes of the actual library (changes to src). labels Jun 15, 2026
@github-actions github-actions Bot changed the title fix: Register state tensors as buffers refactor(aggregation): Register state tensors as buffers Jun 15, 2026
@ValerianRey ValerianRey merged commit c17f009 into SimplexLab:main Jun 15, 2026
21 of 23 checks passed
@ValerianRey

Copy link
Copy Markdown
Member

Arguably I think we could go 1 step further and also register as buffers the tensors that can only be changed through the init or through a setter.

This way, calling .to(device) will actually always work, as far as I know.
The only question is: do we want to have such tensors in the state dict of the aggregator / scalarizer? I don't think so, because these are not necessarily stateful aggregators. So in my opinion we should use persistent=False when registering those kinds of buffers. @KhusPatel4450 wanna take that one too?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: refactor Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants