feat(aggregation): Add SDMGradWeighting #728
Looking good already!
My biggest concern is to understand where the official implementation differs from the paper, and to make sure that we're implementing the right version.
As far as I understand, here are the differences between the paper and the official implementation:
- Aliasing of the Jacobians in the official implementation, as you reported on Discord, meaning that only the third Jacobian is actually used. => Clearly a bug that we don't want to reproduce, but that we do want to report. This bug doesn't seem to be in LibMTL though.
- Division by `scale` in both the official implementation and in LibMTL:
In the official implementation:
scale = torch.mean(torch.sqrt(torch.diag(GG) + 1e-4))
GG = GG / scale.pow(2)
In LibMTL:
GG_diag = torch.diag(GG)
GG_diag = torch.where(GG_diag < 0, torch.zeros_like(GG_diag), GG_diag)
scale = torch.mean(torch.sqrt(GG_diag))
GG = GG / (scale.pow(2) + 1e-8)
I don't see that in the paper's algorithm (but maybe it's in the appendices or somewhere else). Should we do that too? Is it actually a form of normalization that could be implemented elsewhere?
- The final update in the official impl (and in LibMTL) is:
g0 = torch.mean(zeta_grads, dim=1)
gw = torch.sum(zeta_grads * w, dim=1)
g = (gw + lamda * g0) / (1 + lamda)
I don't understand why they divide by (1 + lambda). Should we do that too? It seems you did, but we need to understand why. EDIT: I just read your PR message saying that this is to make the weights sum to 1. It makes sense. Maybe we can just keep it this way then! Also, this is just equivalent to a constant LR factor (unless lambda changes), so it doesn't matter much, and it's better to be equivalent to existing implementations there.
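To make the sum-to-1 point concrete, here is a minimal pure-Python sketch (not the official or LibMTL code). The helper names `combine` and `effective_weights` and the toy gradients are made up for illustration, and the gradients are stored as a list of per-task vectors rather than as the columns of `zeta_grads`. It checks that `(gw + lamda * g0) / (1 + lamda)` is the same as weighting task `i` by `(w[i] + lamda / m) / (1 + lamda)`, and that those weights sum to 1 whenever `w` does.

```python
def combine(grads, w, lamda):
    """Update direction as in the quoted snippet.

    grads: list of m per-task gradient vectors (hypothetical toy layout).
    w: list of m weights on the simplex.
    """
    m = len(grads)
    dim = len(grads[0])
    # g0 is the plain average of the task gradients.
    g0 = [sum(g[k] for g in grads) / m for k in range(dim)]
    # gw is the w-weighted sum of the task gradients.
    gw = [sum(w[i] * grads[i][k] for i in range(m)) for k in range(dim)]
    return [(gw[k] + lamda * g0[k]) / (1 + lamda) for k in range(dim)]


def effective_weights(w, lamda):
    """Per-task weights that reproduce combine() as a single weighted sum."""
    m = len(w)
    return [(wi + lamda / m) / (1 + lamda) for wi in w]


# Toy data, chosen arbitrarily for the illustration.
grads = [[1.0, 0.0], [0.0, 2.0], [3.0, -1.0]]
w = [0.5, 0.25, 0.25]
lamda = 0.3

g = combine(grads, w, lamda)
weights = effective_weights(w, lamda)
g_from_weights = [
    sum(weights[i] * grads[i][k] for i in range(len(grads)))
    for k in range(len(grads[0]))
]
```

Without the division, the same weights would sum to `1 + lamda`, which is the constant factor mentioned in the comment.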
Hello, this should address pretty much all of your comments. The previous commit has most of the code changes; this new one just fixes the testing error. I moved _projection2simplex out of both MoDoWeighting and SDMGradWeighting into aggregation._utils.simplex so it's shared, and moved the known-values test to tests/unit/aggregation/_utils/test_simplex.py, which makes things cleaner.
Very cool, thanks for fixing everything. All my comments have been addressed, except that I still don't understand why they normalize with |
Regarding your comment, I think the main reason is that it keeps the inner loop numerically stable regardless of the gradient scale. As for whether we should add it: we could instead just add a note that users should normalize A themselves in certain situations. Sometimes it might not be beneficial, so leaving that choice up to the user makes the most sense to me.
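As a rough illustration of the scale argument, here is a pure-Python sketch that follows the LibMTL variant quoted earlier (clamp negative diagonal entries, add `1e-8` in the denominator) on a toy Gram matrix. The helper names `gram` and `normalize_gram` and the toy gradients are assumptions for the example, not code from this PR. It shows that multiplying every gradient by 1000 leaves the normalized matrix essentially unchanged, and that the mean gradient norm after normalization is about 1.

```python
import math


def gram(grads):
    """Gram matrix GG[i][j] = <grads[i], grads[j]> for a list of vectors."""
    m = len(grads)
    return [
        [sum(a * b for a, b in zip(grads[i], grads[j])) for j in range(m)]
        for i in range(m)
    ]


def normalize_gram(GG, eps=1e-8):
    """Divide GG by the squared mean gradient norm (LibMTL-style variant)."""
    # Clamp negative diagonal entries to zero before taking the square root.
    diag = [max(GG[i][i], 0.0) for i in range(len(GG))]
    scale = sum(math.sqrt(d) for d in diag) / len(diag)
    denom = scale ** 2 + eps
    return [[x / denom for x in row] for row in GG]


# Toy gradients, and the same gradients at a 1000x larger scale.
grads = [[1.0, 0.0], [0.0, 2.0], [3.0, -1.0]]
big_grads = [[1000.0 * x for x in g] for g in grads]

GG_small = normalize_gram(gram(grads))
GG_big = normalize_gram(gram(big_grads))

# Largest entrywise difference between the two normalized matrices.
max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(GG_small, GG_big)
    for a, b in zip(row_a, row_b)
)

# Mean gradient norm implied by the normalized matrix.
mean_norm = sum(math.sqrt(GG_small[i][i]) for i in range(3)) / 3
```

Since the inner loop only sees the normalized matrix, a fixed inner learning rate behaves the same whatever the raw gradient magnitude is, which is the stability argument above.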
Actually I asked Claude and it gave me a pretty good answer: this is actually mentioned in appendix 6.1, second paragraph. It also seems to pair well with the division by (1 + lambda) at the end. So I think we should actually add this in our implementation of SDMGrad. Should be easy to add, but the manually computed examples will probably need to be updated. Source: https://claude.ai/share/a7ae2ca8-6952-4388-a81e-ec8f5e82d4cc
Hello, after looking at the Claude conversation, I added the scale normalization. It is there now in this new commit.
Adds `SDMGradWeighting` from Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms (NeurIPS 2023).

It mirrors `MoDoWeighting`'s structure: the user computes the cross-batch matrix `A = J_1 @ J_2.T` from two independent mini-batches (via `autojac.jac`) and passes it to the weighting. The weighting runs the inner simplex-projected solve (matching the official OptMN-Lab/LibMTL momentum-SGD inner loop), tracks `w` across calls, and returns the direction-augmented weights so the parameter update is the usual `losses.backward(weights)`.

Two points worth a look:
- `(1+λ)` normalization: the returned weights are `(w_S + λ·w̃)/(1+λ)` (sum to 1), matching the official implementation and LibMTL (`g = (gw + λ·g0)/(1+λ)`).
- `lr=10`, `momentum=0.5`, `n_iter=20` follow the official OptMN-Lab class; `lamda=0.3` follows the official `run.sh` experiments and LibMTL (their class default `0.6` is overridden to `0.3` in their own experiments).

Includes unit tests, docs, a NOTICES entry, and a CHANGELOG entry.
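For readers who want a feel for the flow described above, here is a rough pure-Python sketch of a momentum-SGD inner loop with simplex projection. It is not the code in this PR: `project_to_simplex` and `sdmgrad_weights_sketch` are hypothetical names, the projection is a generic sort-based Euclidean projection, the gradient `A @ (w + λ·w̃)` with uniform `w̃` is my reading of the objective, scale normalization is left out, and the toy matrix `A` is arbitrary. Only the default values are taken from the description.

```python
def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) == 1} (sort-based)."""
    u = sorted(v, reverse=True)
    cumsum = 0.0
    theta = 0.0
    for k, uk in enumerate(u, start=1):
        cumsum += uk
        t = (cumsum - 1.0) / k
        # The last k for which this holds gives the final threshold.
        if uk - t > 0:
            theta = t
    return [max(x - theta, 0.0) for x in v]


def sdmgrad_weights_sketch(A, w, lamda=0.3, lr=10.0, momentum=0.5, n_iter=20):
    """Return (updated w, direction-augmented weights) for an m x m matrix A."""
    m = len(w)
    w_tilde = [1.0 / m] * m  # assumed uniform preference direction
    buf = [0.0] * m          # momentum buffer
    for _ in range(n_iter):
        target = [w[j] + lamda * w_tilde[j] for j in range(m)]
        # Assumed gradient of the inner objective: A @ (w + lamda * w_tilde).
        grad = [sum(A[i][j] * target[j] for j in range(m)) for i in range(m)]
        buf = [momentum * b + g for b, g in zip(buf, grad)]
        w = project_to_simplex([wi - lr * bi for wi, bi in zip(w, buf)])
    returned = [(w[j] + lamda * w_tilde[j]) / (1 + lamda) for j in range(m)]
    return w, returned


# Toy cross-batch matrix (arbitrary values) and a uniform starting point.
A = [[1.0, 0.2], [0.1, 4.0]]
w_state, weights = sdmgrad_weights_sketch(A, [0.5, 0.5])
projected = project_to_simplex([2.0, 0.0])
```

The caller keeps `w_state` and passes it back in on the next step, which is the "tracks `w` across calls" part; `weights` is what would be handed to the backward pass. The sketch makes no claim about convergence for these toy values, it only keeps both vectors on the simplex.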