Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `SDMGradWeighting` from [Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms](https://arxiv.org/pdf/2305.18409) (NeurIPS 2023). It is a stateful `Weighting` that solves for task weights via a simplex-projected inner loop on a cross-batch matrix `A = J_1 @ J_2.T` (computed from two independent mini-batches using `autojac.jac`), with a direction-oriented regularizer pulling the descent direction toward a preference direction.
- Added `IMTL-L` (the loss-balancing variant of Impartial Multi-Task Learning) from [Towards
Impartial Multi-Task Learning](https://openreview.net/pdf?id=IMPnRXEWpvr) (ICLR 2021), a stateful
`Scalarizer` that learns a per-task scale `s_i` and combines the values as
Expand Down
28 changes: 28 additions & 0 deletions NOTICES
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

-------------------------------------------------------------------------------

Project: SDMGrad
Source: https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py
Used in: src/torchjd/aggregation/_sdmgrad.py

MIT License

Copyright (c) 2023 ml-opt-lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ Abstract base classes
nash_mtl.rst
pcgrad.rst
random.rst
sdmgrad.rst
sum.rst
trimmed_mean.rst
7 changes: 7 additions & 0 deletions docs/source/docs/aggregation/sdmgrad.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

SDMGrad
=======

.. autoclass:: torchjd.aggregation.SDMGradWeighting
:members: __call__, reset
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from ._nash_mtl import NashMTL
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sdmgrad import SDMGradWeighting
from ._sum import Sum, SumWeighting
from ._trimmed_mean import TrimmedMean
from ._upgrad import UPGrad, UPGradWeighting
Expand Down Expand Up @@ -93,6 +94,7 @@
"PCGradWeighting",
"Random",
"RandomWeighting",
"SDMGradWeighting",
"Sum",
"SumWeighting",
"TrimmedMean",
Expand Down
19 changes: 2 additions & 17 deletions src/torchjd/aggregation/_modo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchjd.aggregation._mixins import _NonDifferentiable
from torchjd.linalg import Matrix

from ._utils.simplex import _projection2simplex
from ._weighting_bases import _MatrixWeighting


Expand Down Expand Up @@ -166,27 +167,11 @@ def forward(self, matrix: Matrix, /) -> Tensor:
lambd = cast(Tensor, self._lambda)

grad = matrix @ lambd + self._rho * lambd
lambd = self._projection2simplex(lambd - self._gamma * grad)
lambd = _projection2simplex(lambd - self._gamma * grad)

self._lambda = lambd
return lambd

@staticmethod
def _projection2simplex(y: Tensor) -> Tensor:
"""Euclidean projection of ``y`` onto the probability simplex."""

m = len(y)
sorted_y = torch.sort(y, descending=True)[0]
tmpsum = y.new_zeros(())
tmax_f = (torch.sum(y) - 1.0) / m
for i in range(m - 1):
tmpsum = tmpsum + sorted_y[i]
tmax = (tmpsum - 1.0) / (i + 1.0)
if tmax > sorted_y[i + 1]:
tmax_f = tmax
break
return torch.max(y - tmax_f, y.new_zeros(m))

def _ensure_state(self, matrix: Matrix) -> None:
key = (matrix.shape[0], matrix.dtype, matrix.device)
if self._state_key == key and self._lambda is not None:
Expand Down
214 changes: 214 additions & 0 deletions src/torchjd/aggregation/_sdmgrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Partly adapted from https://github.com/OptMN-Lab/SDMGrad — MIT License, Copyright (c) 2023 ml-opt-lab.
# See NOTICES for the full license text.
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd._mixins import Stateful
from torchjd.aggregation._mixins import _NonDifferentiable
from torchjd.linalg import Matrix

from ._utils.simplex import _projection2simplex
from ._weighting_bases import _MatrixWeighting


class SDMGradWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
    r"""
    :class:`~torchjd.Stateful`
    :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Direction-oriented
    Multi-objective Learning: Simple and Provable Stochastic Algorithms
    <https://arxiv.org/pdf/2305.18409>`_ (NeurIPS 2023).

    .. warning::
        The input matrix must be :math:`A = J_1 J_2^\top`, computed from two **independent**
        mini-batches via :func:`torchjd.autojac.jac`. It is **not** a Gramian and is not symmetric
        or positive semi-definite in general. See the usage example below.

    :param lr: Learning rate of the inner SGD that solves for the task weights. Must be positive.
    :param momentum: Momentum of the inner SGD. Must be in :math:`[0, 1)`.
    :param n_iter: Number of inner SGD iterations performed at each call. Must be positive.
    :param lambda_: Coefficient controlling how strongly the descent direction is pulled toward
        the preference direction. Must be non-negative.
    :param pref_vector: The preference vector :math:`\tilde w` defining the target direction. If
        not provided, defaults to the uniform vector :math:`[1/m, \ldots, 1/m]` (i.e. the target
        direction is the average gradient).

    .. note::
        The inner simplex-projected solver is adapted from the `official implementation
        <https://github.com/OptMN-Lab/SDMGrad/blob/main/methods/weight_methods.py>`_. The official
        class default for the regularization coefficient (``lambda_`` here) is ``0.6``, but it is
        overridden to ``0.3`` in their own experiments, which is the default used here (and in
        `LibMTL
        <https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/SDMGrad.py>`_).

        Before the inner solve, the input matrix is scale-normalized by the mean of the square
        roots of its diagonal entries, clamped to be non-negative (following both the official
        implementation and LibMTL). This makes the inner SGD learning rate scale-invariant with
        respect to gradient magnitude. The normalization is briefly described in section 6.1 of
        the paper.

    .. admonition:: Example (three batches per step)

        The following example shows how to train with the SDMGrad algorithm.

        .. testcode::

            import torch
            from torch.nn import Linear, MSELoss, ReLU, Sequential
            from torch.optim import SGD

            from torchjd.aggregation import SDMGradWeighting
            from torchjd.autojac import jac

            # Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
            inputs = torch.randn(9, 16, 5)
            targets = torch.randn(9, 16)

            model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
            optimizer = SGD(model.parameters())
            criterion = MSELoss(reduction="none")
            weighting = SDMGradWeighting(lambda_=0.3)
            params = list(model.parameters())

            # Consume three consecutive (independent) batches per step.
            for i in range(len(inputs) // 3):
                # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm.
                input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
                target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

                losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
                jacs_1 = jac(losses_1, params)
                J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

                losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
                jacs_2 = jac(losses_2, params)
                J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

                A = J_1 @ J_2.T
                weights = weighting(A)

                losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
                losses_3.backward(weights)
                optimizer.step()
                optimizer.zero_grad()
    """

    def __init__(
        self,
        lr: float = 10.0,
        momentum: float = 0.5,
        n_iter: int = 20,
        lambda_: float = 0.3,
        pref_vector: Tensor | None = None,
    ) -> None:
        super().__init__()

        # Lazily-created solver state: the task weights carried over between calls and the
        # (number of rows, dtype, device) signature they were created for.
        self._w: Tensor | None = None
        self._state_key: tuple[int, torch.dtype, torch.device] | None = None

        # Hyper-parameters go through the validating property setters below.
        self.lr = lr
        self.momentum = momentum
        self.n_iter = n_iter
        self.lambda_ = lambda_
        self.pref_vector = pref_vector

    @property
    def lr(self) -> float:
        """Learning rate of the inner SGD; setting a non-positive value raises ``ValueError``."""

        return self._lr

    @lr.setter
    def lr(self, value: float) -> None:
        non_positive = value <= 0.0
        if non_positive:
            raise ValueError(f"Attribute `lr` must be positive. Found lr={value!r}.")
        self._lr = value

    @property
    def momentum(self) -> float:
        """Momentum of the inner SGD; values outside ``[0, 1)`` raise ``ValueError`` on set."""

        return self._momentum

    @momentum.setter
    def momentum(self, value: float) -> None:
        in_range = 0.0 <= value < 1.0
        if not in_range:
            raise ValueError(f"Attribute `momentum` must be in [0, 1). Found momentum={value!r}.")
        self._momentum = value

    @property
    def n_iter(self) -> int:
        """Number of inner SGD iterations per call; values below 1 raise ``ValueError`` on set."""

        return self._n_iter

    @n_iter.setter
    def n_iter(self, value: int) -> None:
        too_small = value < 1
        if too_small:
            raise ValueError(
                f"Attribute `n_iter` must be a positive integer. Found n_iter={value!r}."
            )
        self._n_iter = value

    @property
    def lambda_(self) -> float:
        """Strength of the pull toward the preference direction; negative values are rejected."""

        return self._lambda

    @lambda_.setter
    def lambda_(self, value: float) -> None:
        negative = value < 0.0
        if negative:
            raise ValueError(f"Attribute `lambda_` must be non-negative. Found lambda_={value!r}.")
        self._lambda = value

    @property
    def pref_vector(self) -> Tensor | None:
        """Preference vector, or ``None`` for the uniform default; must be 1D when provided."""

        return self._pref_vector

    @pref_vector.setter
    def pref_vector(self, value: Tensor | None) -> None:
        if value is None:
            self._pref_vector = None
            return
        if value.ndim != 1:
            raise ValueError(
                "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
                f"{value.ndim}`."
            )
        self._pref_vector = value

    def reset(self) -> None:
        """Forgets the stored task weights, so that the next forward restarts from uniform."""

        self._state_key = None
        self._w = None

    def forward(self, matrix: Matrix, /) -> Tensor:
        """
        Runs ``n_iter`` projected SGD steps on the task weights and returns the combined weights.

        The task weights found at the end of the inner loop are stored and used as the starting
        point of the next call.

        :param matrix: The cross-batch matrix to extract the weights from.
        """

        self._ensure_state(matrix)
        weights = cast(Tensor, self._w)
        # Loop-invariant pull toward the preference direction.
        pull = self._lambda * self._resolve_w_tilde(matrix)

        # Normalize the matrix by the squared mean of the square roots of its (clamped) diagonal.
        nonneg_diag = torch.diag(matrix).clamp(min=0.0)
        scale = nonneg_diag.sqrt().mean()
        normalized = matrix / (scale.pow(2) + 1e-8)

        # The momentum buffer is local to this call: it starts empty at every forward.
        velocity: Tensor | None = None
        for _ in range(self._n_iter):
            grad = normalized @ (weights + pull)
            if velocity is None:
                velocity = grad
            else:
                velocity = self._momentum * velocity + grad
            weights = _projection2simplex(weights - self._lr * velocity)

        self._w = weights
        return (weights + pull) / (1.0 + self._lambda)

    def _resolve_w_tilde(self, matrix: Matrix) -> Tensor:
        """
        Returns the preference vector to use for ``matrix``, with its dtype and device.

        :raises ValueError: If a preference vector was provided and its length differs from the
            number of rows of ``matrix``.
        """

        pref = self._pref_vector
        n_rows = matrix.shape[0]
        if pref is None:
            # Uniform default, created directly with the dtype and device of the matrix.
            return matrix.new_full((n_rows,), 1.0 / n_rows)

        pref_len = pref.shape[0]
        if pref_len != n_rows:
            raise ValueError(
                "The length of `pref_vector` must match the number of rows of the input matrix. "
                f"Found len(pref_vector)={pref_len} and matrix.shape[0]={n_rows}."
            )
        return pref.to(dtype=matrix.dtype, device=matrix.device)

    def _ensure_state(self, matrix: Matrix) -> None:
        """
        (Re)initializes the stored weights to uniform when they are missing or were created for a
        different number of rows, dtype or device than those of ``matrix``.
        """

        n_rows = matrix.shape[0]
        signature = (n_rows, matrix.dtype, matrix.device)
        stale = self._w is None or self._state_key != signature
        if stale:
            self._w = matrix.new_full((n_rows,), 1.0 / n_rows)
            self._state_key = signature

    def __repr__(self) -> str:
        fields = (
            f"lr={self.lr!r}",
            f"momentum={self.momentum!r}",
            f"n_iter={self.n_iter!r}",
            f"lambda_={self.lambda_!r}",
            f"pref_vector={self.pref_vector!r}",
        )
        return f"{self.__class__.__name__}({', '.join(fields)})"
18 changes: 18 additions & 0 deletions src/torchjd/aggregation/_utils/simplex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from torch import Tensor


def _projection2simplex(y: Tensor) -> Tensor:
    """
    Return the Euclidean projection of the vector ``y`` onto the probability simplex.

    A single threshold is subtracted from every entry of ``y`` and negative results are set to
    zero. The threshold is searched over the entries sorted in decreasing order: the first
    candidate that is strictly larger than the next sorted entry wins; if none is, the threshold
    that makes all entries sum to one is used.

    :param y: The vector to project.
    """

    n = len(y)
    descending = torch.sort(y, descending=True).values
    # Default threshold: every entry is kept and shifted by the same amount.
    threshold = (torch.sum(y) - 1.0) / n
    running = y.new_zeros(())
    for k in range(1, n):
        # ``running`` is the sum of the k largest entries.
        running = running + descending[k - 1]
        candidate = (running - 1.0) / float(k)
        if candidate > descending[k]:
            threshold = candidate
            break
    shifted = y - threshold
    return torch.max(shifted, y.new_zeros(n))
19 changes: 19 additions & 0 deletions tests/unit/aggregation/_utils/test_simplex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from torch.testing import assert_close
from utils.tensors import tensor_

from torchjd.aggregation._utils.simplex import _projection2simplex


def test_projection2simplex_known_values() -> None:
    """Checks ``_projection2simplex`` against Euclidean projections worked out by hand."""

    known_projections = [
        # Already-positive input: the deficit (1 - sum) is spread equally, no clamping.
        ([0.5, 0.1, 0.1], [0.6, 0.2, 0.2]),
        # Input with a negative entry: it gets clamped to zero.
        ([1.0, 0.0, -0.5], [1.0, 0.0, 0.0]),
    ]
    for point, expected in known_projections:
        assert_close(_projection2simplex(tensor_(point)), tensor_(expected))
15 changes: 0 additions & 15 deletions tests/unit/aggregation/test_modo.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,3 @@ def test_non_symmetric_input() -> None:
assert_close(W(G), expected)
assert W(G).shape == (m,)
assert (W(G) >= 0).all()


def test_projection2simplex_known_values() -> None:
    """Compares the simplex projection with two projections computed by hand."""

    project = MoDoWeighting._projection2simplex

    # All entries remain positive: the deficit (1 - sum) is shared equally and nothing is clamped.
    positive_point = tensor_([0.5, 0.1, 0.1])
    assert_close(project(positive_point), tensor_([0.6, 0.2, 0.2]))

    # One entry is negative: the projection clamps it to zero.
    mixed_sign_point = tensor_([1.0, 0.0, -0.5])
    assert_close(project(mixed_sign_point), tensor_([1.0, 0.0, 0.0]))
Loading
Loading