-
Notifications
You must be signed in to change notification settings - Fork 19
feat(aggregation): Add SDMGradWeighting #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ValerianRey
merged 5 commits into
SimplexLab:main
from
KhusPatel4450:feat/sdmgrad-weighting
Jun 11, 2026
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
c409c4a
feat(aggregation): Add SDMGradWeighting
KhusPatel4450 f3c3d48
refactor(aggregation): address PR review comments on SDMGradWeighting
KhusPatel4450 db86b53
fix(aggregation): use eye_ helper to respect DTYPE in test_two_consec…
KhusPatel4450 21a8c53
Merge branch 'main' into feat/sdmgrad-weighting
ValerianRey 6f1c9d2
feat(aggregation): add scale normalization to SDMGradWeighting
KhusPatel4450 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,5 +41,6 @@ Abstract base classes | |
| nash_mtl.rst | ||
| pcgrad.rst | ||
| random.rst | ||
| sdmgrad.rst | ||
| sum.rst | ||
| trimmed_mean.rst | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| SDMGrad | ||
| ======= | ||
|
|
||
| .. autoclass:: torchjd.aggregation.SDMGradWeighting | ||
| :members: __call__, reset |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,214 @@ | ||
| # Partly adapted from https://github.com/OptMN-Lab/SDMGrad — MIT License, Copyright (c) 2023 ml-opt-lab. | ||
| # See NOTICES for the full license text. | ||
| from __future__ import annotations | ||
|
|
||
| from typing import cast | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| from torchjd._mixins import Stateful | ||
| from torchjd.aggregation._mixins import _NonDifferentiable | ||
| from torchjd.linalg import Matrix | ||
|
|
||
| from ._utils.simplex import _projection2simplex | ||
| from ._weighting_bases import _MatrixWeighting | ||
|
|
||
|
|
||
| class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable): | ||
| r""" | ||
| :class:`~torchjd.Stateful` | ||
| :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Direction-oriented | ||
| Multi-objective Learning: Simple and Provable Stochastic Algorithms | ||
| <https://arxiv.org/pdf/2305.18409>`_ (NeurIPS 2023). | ||
|
|
||
| .. warning:: | ||
| The input matrix must be :math:`A = J_1 J_2^\top`, computed from two **independent** | ||
| mini-batches via :func:`torchjd.autojac.jac`. It is **not** a Gramian and is not symmetric | ||
| or positive semi-definite in general. See the usage examples below. | ||
|
|
||
| :param lr: Learning rate of the inner SGD that solves for the task weights. Must be positive. | ||
| :param momentum: Momentum of the inner SGD. Must be in :math:`[0, 1)`. | ||
| :param n_iter: Number of inner SGD iterations performed at each call. Must be positive. | ||
| :param lambda_: Non-negative coefficient controlling how strongly the descent direction is pulled | ||
| toward the preference direction. Must be non-negative. | ||
| :param pref_vector: The preference vector :math:`\tilde w` defining the target direction. If not | ||
| provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the target diection is the average gradient). | ||
|
|
||
| .. note:: | ||
| The inner simplex-projected solver is adapted from the `official implementation | ||
| <https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py>`_. Note that the | ||
| official class default for this coefficient is ``0.6``, overridden to ``0.3`` in their own | ||
| experiments, which is the value used here (and in `LibMTL | ||
| <https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/SDMGrad.py>`_). | ||
|
|
||
| Before the inner solve, the input matrix is scale-normalized by the mean of the square | ||
| roots of its non-negative diagonal entries (following both the official implementation and | ||
| LibMTL). This makes the inner SGD learning rate scale-invariant with respect to gradient | ||
| magnitude. The normalization is briefly described in section 6.1 of the paper. | ||
|
|
||
| .. admonition:: Example (three batches per step) | ||
|
|
||
| The following example shows how to train with the SDMGrad algorithm. | ||
|
|
||
| .. testcode:: | ||
|
|
||
| import torch | ||
| from torch.nn import Linear, MSELoss, ReLU, Sequential | ||
| from torch.optim import SGD | ||
|
|
||
| from torchjd.aggregation import SDMGradWeighting | ||
| from torchjd.autojac import jac | ||
|
|
||
| # Generate data (9 batches of 16 examples of dim 5) for the sake of the example. | ||
| inputs = torch.randn(9, 16, 5) | ||
| targets = torch.randn(9, 16) | ||
|
|
||
| model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) | ||
| optimizer = SGD(model.parameters()) | ||
| criterion = MSELoss(reduction="none") | ||
| weighting = SDMGradWeighting(lambda_=0.3) | ||
| params = list(model.parameters()) | ||
|
|
||
| # Consume three consecutive (independent) batches per step. | ||
| for i in range(len(inputs) // 3): | ||
| # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm. | ||
| input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2] | ||
| target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2] | ||
|
|
||
| losses_1 = criterion(model(input_1).squeeze(dim=1), target_1) | ||
| jacs_1 = jac(losses_1, params) | ||
| J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1) | ||
|
|
||
| losses_2 = criterion(model(input_2).squeeze(dim=1), target_2) | ||
| jacs_2 = jac(losses_2, params) | ||
| J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1) | ||
|
|
||
| A = J_1 @ J_2.T | ||
| weights = weighting(A) | ||
|
|
||
| losses_3 = criterion(model(input_3).squeeze(dim=1), target_3) | ||
| losses_3.backward(weights) | ||
| optimizer.step() | ||
| optimizer.zero_grad() | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| lr: float = 10.0, | ||
| momentum: float = 0.5, | ||
| n_iter: int = 20, | ||
| lambda_: float = 0.3, | ||
| pref_vector: Tensor | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.lr = lr | ||
| self.momentum = momentum | ||
| self.n_iter = n_iter | ||
| self.lambda_ = lambda_ | ||
| self.pref_vector = pref_vector | ||
| self._w: Tensor | None = None | ||
| self._state_key: tuple[int, torch.dtype, torch.device] | None = None | ||
|
|
||
| @property | ||
| def lr(self) -> float: | ||
| return self._lr | ||
|
|
||
| @lr.setter | ||
| def lr(self, value: float) -> None: | ||
| if value <= 0.0: | ||
| raise ValueError(f"Attribute `lr` must be positive. Found lr={value!r}.") | ||
| self._lr = value | ||
|
|
||
| @property | ||
| def momentum(self) -> float: | ||
| return self._momentum | ||
|
|
||
| @momentum.setter | ||
| def momentum(self, value: float) -> None: | ||
| if not (0.0 <= value < 1.0): | ||
| raise ValueError(f"Attribute `momentum` must be in [0, 1). Found momentum={value!r}.") | ||
| self._momentum = value | ||
|
|
||
| @property | ||
| def n_iter(self) -> int: | ||
| return self._n_iter | ||
|
|
||
| @n_iter.setter | ||
| def n_iter(self, value: int) -> None: | ||
| if value < 1: | ||
| raise ValueError( | ||
| f"Attribute `n_iter` must be a positive integer. Found n_iter={value!r}." | ||
| ) | ||
| self._n_iter = value | ||
|
|
||
| @property | ||
| def lambda_(self) -> float: | ||
| return self._lambda | ||
|
|
||
| @lambda_.setter | ||
| def lambda_(self, value: float) -> None: | ||
| if value < 0.0: | ||
| raise ValueError(f"Attribute `lambda_` must be non-negative. Found lambda_={value!r}.") | ||
| self._lambda = value | ||
|
|
||
| @property | ||
| def pref_vector(self) -> Tensor | None: | ||
| return self._pref_vector | ||
|
|
||
| @pref_vector.setter | ||
| def pref_vector(self, value: Tensor | None) -> None: | ||
| if value is not None and value.ndim != 1: | ||
| raise ValueError( | ||
| "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " | ||
| f"{value.ndim}`." | ||
| ) | ||
| self._pref_vector = value | ||
|
|
||
| def reset(self) -> None: | ||
| """Clears the stored task weights so the next forward starts from uniform.""" | ||
|
|
||
| self._w = None | ||
| self._state_key = None | ||
|
|
||
| def forward(self, matrix: Matrix, /) -> Tensor: | ||
| self._ensure_state(matrix) | ||
| w = cast(Tensor, self._w) | ||
| w_tilde = self._resolve_w_tilde(matrix) | ||
|
|
||
| diag = torch.diag(matrix).clamp(min=0.0) | ||
| scale = diag.sqrt().mean() | ||
| a = matrix / (scale.pow(2) + 1e-8) | ||
|
|
||
| velocity: Tensor | None = None | ||
| for _ in range(self._n_iter): | ||
| grad = a @ (w + self._lambda * w_tilde) | ||
| velocity = grad if velocity is None else self._momentum * velocity + grad | ||
| w = _projection2simplex(w - self._lr * velocity) | ||
|
|
||
| self._w = w | ||
| return (w + self._lambda * w_tilde) / (1.0 + self._lambda) | ||
|
|
||
| def _resolve_w_tilde(self, matrix: Matrix) -> Tensor: | ||
| m = matrix.shape[0] | ||
| if self._pref_vector is None: | ||
| return matrix.new_full((m,), 1.0 / m) | ||
| if self._pref_vector.shape[0] != m: | ||
| raise ValueError( | ||
| "The length of `pref_vector` must match the number of rows of the input matrix. " | ||
| f"Found len(pref_vector)={self._pref_vector.shape[0]} and matrix.shape[0]={m}." | ||
| ) | ||
| return self._pref_vector.to(dtype=matrix.dtype, device=matrix.device) | ||
|
|
||
| def _ensure_state(self, matrix: Matrix) -> None: | ||
| key = (matrix.shape[0], matrix.dtype, matrix.device) | ||
| if self._state_key == key and self._w is not None: | ||
| return | ||
| self._w = matrix.new_full((matrix.shape[0],), 1.0 / matrix.shape[0]) | ||
| self._state_key = key | ||
|
|
||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"{self.__class__.__name__}(lr={self.lr!r}, momentum={self.momentum!r}, " | ||
| f"n_iter={self.n_iter!r}, lambda_={self.lambda_!r}, pref_vector={self.pref_vector!r})" | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| def _projection2simplex(y: Tensor) -> Tensor: | ||
| """Euclidean projection of ``y`` onto the probability simplex.""" | ||
|
|
||
| m = len(y) | ||
| sorted_y = torch.sort(y, descending=True)[0] | ||
| tmpsum = y.new_zeros(()) | ||
| tmax_f = (torch.sum(y) - 1.0) / m | ||
| for i in range(m - 1): | ||
| tmpsum = tmpsum + sorted_y[i] | ||
| tmax = (tmpsum - 1.0) / (i + 1.0) | ||
| if tmax > sorted_y[i + 1]: | ||
| tmax_f = tmax | ||
| break | ||
| return torch.max(y - tmax_f, y.new_zeros(m)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| from torch.testing import assert_close | ||
| from utils.tensors import tensor_ | ||
|
|
||
| from torchjd.aggregation._utils.simplex import _projection2simplex | ||
|
|
||
|
|
||
| def test_projection2simplex_known_values() -> None: | ||
| """The simplex projection matches hand-computed Euclidean projections.""" | ||
|
|
||
| # Already-positive input: the deficit (1 - sum) is spread equally, no clamping. | ||
| assert_close( | ||
| _projection2simplex(tensor_([0.5, 0.1, 0.1])), | ||
| tensor_([0.6, 0.2, 0.2]), | ||
| ) | ||
| # Input with a negative entry: it gets clamped to zero. | ||
| assert_close( | ||
| _projection2simplex(tensor_([1.0, 0.0, -0.5])), | ||
| tensor_([1.0, 0.0, 0.0]), | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.