feat(aggregation): Add SDMGradWeighting #728

Merged
ValerianRey merged 5 commits into SimplexLab:main from KhusPatel4450:feat/sdmgrad-weighting on Jun 11, 2026

@KhusPatel4450

Contributor

Adds SDMGradWeighting from Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms (NeurIPS 2023).

It mirrors MoDoWeighting's structure: the user computes the cross-batch matrix A = J_1 @ J_2.T from two independent mini-batches (via autojac.jac) and passes it to the weighting. The weighting runs the inner simplex-projected solve (matching the official OptMN-Lab/LibMTL momentum-SGD inner loop), tracks w across calls, and returns the direction-augmented weights so the parameter update is the usual losses.backward(weights).

Two points worth a look:

  1. (1+λ) normalization: the returned weights are (w_S + λ·w̃)/(1+λ) (sum to 1), matching the official implementation and LibMTL (g = (gw + λ·g0)/(1+λ)).
  2. Defaults: lr=10, momentum=0.5, n_iter=20 follow the official OptMN-Lab class; lamda=0.3 follows the official run.sh experiments and LibMTL (their class default 0.6 is overridden to 0.3 in their own experiments).
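
To illustrate point 1, here is a minimal pure-Python sketch of the (1+λ) normalization. It assumes w̃ is the uniform vector 1/m, and the helper name `combine_weights` is hypothetical (it is not part of this PR); it only shows that the combined weights still sum to 1 when both inputs do.

```python
def combine_weights(w, w_tilde, lam):
    """Return (w + lam * w_tilde) / (1 + lam), element-wise.

    Hypothetical helper for illustration; w and w_tilde are assumed
    to lie on the probability simplex.
    """
    if len(w) != len(w_tilde):
        raise ValueError("w and w_tilde must have the same length")
    return [(wi + lam * ti) / (1.0 + lam) for wi, ti in zip(w, w_tilde)]


w = [0.7, 0.2, 0.1]          # inner-solve weights (sum to 1)
w_tilde = [1.0 / 3.0] * 3    # assumed uniform preference
combined = combine_weights(w, w_tilde, lam=0.3)
total = sum(combined)        # equals 1 up to floating-point error
```

Without the division by (1+λ), the same weights would sum to 1+λ, which amounts to a constant rescaling of the update as long as λ is fixed.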

Includes unit tests, docs, a NOTICES entry, and a CHANGELOG entry.

@KhusPatel4450 KhusPatel4450 added package: aggregation cc: feat Conventional commit type for new features. labels Jun 10, 2026

@ValerianRey ValerianRey left a comment

Member

Looking good already!

My biggest concern is to understand where the official implementation differs from the paper, and to make sure that we're implementing the right version.

As far as I understand, here are the differences between the paper and the official implementation:

  • Aliasing of the jacobians in the official implementation, as you reported on discord, meaning that only the third jacobian is actually used. => Clearly a bug that we don't want to reproduce, but we want to report. This bug doesn't seem to be in LibMTL though.
  • Division by scale in both the official impl and in LibMTL:
    In the official implementation:

        scale = torch.mean(torch.sqrt(torch.diag(GG) + 1e-4))
        GG = GG / scale.pow(2)

    In LibMTL:

        GG_diag = torch.diag(GG)
        GG_diag = torch.where(GG_diag < 0, torch.zeros_like(GG_diag), GG_diag)
        scale = torch.mean(torch.sqrt(GG_diag))
        GG = GG / (scale.pow(2) + 1e-8)

    I don't see that in the paper's algorithm (but maybe it's in the appendices or somewhere else). Should we do that too? Is it actually a form of normalization that could be implemented elsewhere?

  • The final update in the official impl (and in LibMTL) is:

        g0 = torch.mean(zeta_grads, dim=1)
        gw = torch.sum(zeta_grads * w, dim=1)
        g = (gw + lamda * g0) / (1 + lamda)

    I don't understand why they divide by (1 + lambda). Should we do that too? It seems you did, but we need to understand why. EDIT: I just read your PR message saying that this is to make the weights sum to 1. It makes sense. Maybe we can just keep it this way then! Also, this is just equivalent to a constant (unless lambda changes) LR factor, so it doesn't matter much and it's better to be equivalent to existing implementations there.
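
For the division by scale discussed in the second bullet above, the following pure-Python sketch follows the LibMTL variant (negative diagonal entries clamped to zero, `1e-8` added to the denominator). The function name `normalize_by_scale` and the example matrix are made up for illustration. It shows one visible effect of the normalization: multiplying all gradients by a constant c multiplies the matrix by c², and the normalized matrix is essentially unchanged.

```python
import math


def normalize_by_scale(gg, eps=1e-8):
    """Divide a square matrix by the squared mean of sqrt(diagonal).

    Hypothetical helper mirroring the LibMTL snippet: negative diagonal
    entries are clamped to 0 before the square root.
    """
    m = len(gg)
    diag = [max(gg[i][i], 0.0) for i in range(m)]
    scale = sum(math.sqrt(d) for d in diag) / m
    denom = scale ** 2 + eps
    return [[value / denom for value in row] for row in gg]


gg = [[4.0, 1.0], [1.0, 16.0]]                      # toy 2x2 example
scaled = [[100.0 * value for value in row] for row in gg]  # gradients times 10

normalized = normalize_by_scale(gg)
normalized_scaled = normalize_by_scale(scaled)      # nearly identical to `normalized`
```

With this normalization, fixed inner-loop hyperparameters such as the step size act on a matrix of roughly constant magnitude, whatever the raw gradient scale is.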

@KhusPatel4450

Contributor Author

Hello,

So this pretty much has all your comments addressed. The previous commit has most of the code changes, and then this new one was just to fix the testing error.

I also moved _projection2simplex out of both MoDoWeighting and SDMGradWeighting into aggregation._utils.simplex so it's shared, and moved the known-values test to tests/unit/aggregation/_utils/test_simplex.py. This just makes things cleaner.
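
For readers unfamiliar with the projection step, here is a minimal pure-Python sketch of the standard sort-based Euclidean projection onto the probability simplex. It is an illustration only: the name `project_to_simplex` is hypothetical, and the shared `_projection2simplex` helper in this PR may be implemented differently.

```python
def project_to_simplex(v):
    """Euclidean projection of v onto {w : w_i >= 0, sum(w) = 1}.

    Standard sort-based algorithm, written for clarity rather than speed.
    """
    if not v:
        raise ValueError("v must be non-empty")
    u = sorted(v, reverse=True)
    cumulative = 0.0
    theta = 0.0
    for k, u_k in enumerate(u, start=1):
        cumulative += u_k
        candidate = (cumulative - 1.0) / k
        # The condition holds for a prefix of the sorted values; the last
        # k for which it holds determines the shift theta.
        if u_k - candidate > 0:
            theta = candidate
    return [max(x - theta, 0.0) for x in v]


already_on_simplex = project_to_simplex([0.5, 0.5])   # unchanged
clipped = project_to_simplex([-1.0, 3.0, 0.5])        # mass goes to the largest entry
shifted = project_to_simplex([0.2, 0.2, 0.2])         # each entry raised to 1/3
```

In an inner loop like the one described in this PR, a projection of this kind would be applied to the weight vector after each gradient step so that it stays on the simplex.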

@ValerianRey

ValerianRey commented Jun 10, 2026

Member

Very cool, thanks for fixing everything. All my comments have been addressed, except that I still don't understand why they normalize with scale in LibMTL and in the official implementation (see point 2 of my main review message). Any idea about this?

@KhusPatel4450

Contributor Author

Regarding your comment, I think the main reason is that it keeps the inner loop numerically stable regardless of the gradient scale. As for whether we should add it or not: we could just add a note that users should normalize A themselves in certain situations. Sometimes it might not be beneficial, so leaving that choice up to the user makes the most sense to me.

@ValerianRey

ValerianRey commented Jun 10, 2026

Member

> Regarding your comment, I think the main reason is that it keeps the inner loop numerically stable regardless of the gradient scale. As for whether we should add it or not: we could just add a note that users should normalize A themselves in certain situations. Sometimes it might not be beneficial, so leaving that choice up to the user makes the most sense to me.

Actually I asked Claude and it gave me a pretty good answer: this is actually mentioned in appendix 6.1, second paragraph. It also seems to pair well with the division by (1 + lambda) at the end. So I think we should actually add this in our implementation of SDMGrad. Should be easy to add, but the manually computed examples will probably need to be updated.

Source: https://claude.ai/share/a7ae2ca8-6952-4388-a81e-ec8f5e82d4cc

@KhusPatel4450

Contributor Author

Hello, after looking at the Claude conversation, I added the scale normalization. It is there now in this new commit.

@ValerianRey ValerianRey merged commit 3e5b88c into SimplexLab:main Jun 11, 2026
17 checks passed