diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d5e80a6e..d907bd54e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,14 @@ changes that do not affect the user. ## [Unreleased] +### Changed + +- Changed how the Jacobians are computed when calling `backward` or `mtl_backward` with + `parallel_chunk_size=1` to not rely on `torch.autograd.vmap` in this case. Whenever `vmap` does + not support something (compiled functions, RNN on cuda, etc.), users should now be able to avoid + using `vmap` by calling `backward` or `mtl_backward` with `parallel_chunk_size=1` and + `retain_graph=True`. + ## [0.3.1] - 2024-12-21 ### Changed diff --git a/src/torchjd/autojac/_transform/jac.py b/src/torchjd/autojac/_transform/jac.py index 6650214c3..bb9559ab8 100644 --- a/src/torchjd/autojac/_transform/jac.py +++ b/src/torchjd/autojac/_transform/jac.py @@ -68,10 +68,23 @@ def get_vjp(grad_outputs: Sequence[Tensor]) -> Tensor: grads = _materialize(optional_grads, inputs=inputs) return torch.concatenate([grad.reshape([-1]) for grad in grads]) - # Because of a limitation of vmap, this breaks when some tensors have `retains_grad=True`. - # See https://pytorch.org/functorch/stable/ux_limitations.html for more information. - # This also breaks when some tensors have been produced by compiled functions. - grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=self.chunk_size)(jac_outputs) + if self.chunk_size == 1: + # In this special case, we don't need vmap, and because of the issues of vmap, we're + # better off not using it. In most cases, this should be equivalent to the vmap call, + # but in cases where vmap breaks (compiled functions, RNN on cuda, etc.), this should + # still work. + rows = [] + for i in range(jac_outputs[0].shape[0]): + grad_outputs = [jac_output[i] for jac_output in jac_outputs] + gradient_vector = get_vjp(grad_outputs) + rows.append(gradient_vector) + grouped_jacobian_matrix = torch.vstack(rows) + else: + # Because of a limitation of vmap, this breaks when some tensors have + # `retains_grad=True`. See https://pytorch.org/functorch/stable/ux_limitations.html for + # more information. This also breaks when some tensors have been produced by compiled + # functions, and in some other cases (RNN on cuda, etc.). + grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=self.chunk_size)(jac_outputs) lengths = [input.numel() for input in inputs] jacobian_matrices = _extract_sub_matrices(grouped_jacobian_matrix, lengths) diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index 2a858f007..f7eeb6541 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -1,5 +1,5 @@ import torch -from pytest import raises +from pytest import mark, raises from unit.conftest import DEVICE from torchjd.autojac._transform import Jac, Jacobians @@ -7,7 +7,8 @@ from ._dict_assertions import assert_tensor_dicts_are_close -def test_single_input(): +@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)]) +def test_single_input(chunk_size: int | None, retain_graph: bool): """ Tests that the Jac transform works correctly for an example of multiple differentiation. Here, the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with @@ -20,7 +21,7 @@ def test_single_input(): y = torch.stack([a1 * x, a2 * x]) input = Jacobians({y: torch.eye(2, device=DEVICE)}) - jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=None) + jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=True) jacobians = jac(input) expected_jacobians = { @@ -31,7 +32,8 @@ def test_single_input(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_empty_inputs_1(): +@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)]) +def test_empty_inputs_1(chunk_size: int | None, retain_graph: bool): """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -41,7 +43,7 @@ def test_empty_inputs_1(): y = torch.stack([y1, y2]) input = Jacobians({y: torch.eye(2, device=DEVICE)}) - jac = Jac(outputs=[y], inputs=[], chunk_size=None) + jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size, retain_graph=True) jacobians = jac(input) expected_jacobians = {} @@ -49,7 +51,8 @@ def test_empty_inputs_1(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_empty_inputs_2(): +@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)]) +def test_empty_inputs_2(chunk_size: int | None, retain_graph: bool): """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -62,7 +65,7 @@ def test_empty_inputs_2(): y = torch.stack([y1, y2]) input = Jacobians({y: torch.eye(2, device=DEVICE)}) - jac = Jac(outputs=[y], inputs=[], chunk_size=None) + jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size, retain_graph=True) jacobians = jac(input) expected_jacobians = {} @@ -122,7 +125,8 @@ def test_two_levels(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_multiple_outputs_1(): +@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)]) +def test_multiple_outputs_1(chunk_size: int | None, retain_graph: bool): """ Tests that the Jac transform works correctly when the `outputs` contains 3 vectors. The input (jac_outputs) is not the same for all outputs, so that this test also checks that the @@ -143,7 +147,7 @@ def test_multiple_outputs_1(): jac_output3 = torch.cat([zeros_2x2, zeros_2x2, identity_2x2]) input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3}) - jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None) + jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=True) jacobians = jac(input) zero_scalar = torch.tensor(0.0, device=DEVICE) @@ -155,7 +159,8 @@ def test_multiple_outputs_1(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_multiple_outputs_2(): +@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)]) +def test_multiple_outputs_2(chunk_size: int | None, retain_graph: bool): """ Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of different shapes. @@ -175,7 +180,7 @@ def test_multiple_outputs_2(): jac_output3 = torch.stack([zeros_2, zeros_2, ones_2]) input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3}) - jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None) + jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=True) jacobians = jac(input) expected_jacobians = { diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 5845ae3ea..a3fa0cfc2 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -2,6 +2,7 @@ import torch from pytest import mark, raises +from torch import nn from torch.testing import assert_close from unit._utils import ExceptionContext from unit.conftest import DEVICE @@ -26,11 +27,15 @@ def test_various_aggregators(aggregator: Aggregator): assert (a.grad is not None) and (a.shape == a.grad.shape) -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()]) -@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)]) +@mark.parametrize("aggregator", [Mean(), UPGrad()]) +@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_inputs", [True, False]) +@mark.parametrize("chunk_size", [1, 3, None]) def test_value_is_correct( - aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool + aggregator: Aggregator, + shape: tuple[int, int], + manually_specify_inputs: bool, + chunk_size: int | None, ): """ Tests that the .grad value filled by backward is correct in a simple example of matrix-vector @@ -46,7 +51,13 @@ def test_value_is_correct( else: inputs = None - backward([output], aggregator, inputs=inputs) + backward( + [output], + aggregator, + inputs=inputs, + retain_graph=True, + parallel_chunk_size=chunk_size, + ) assert_close(input.grad, aggregator(J)) @@ -203,3 +214,52 @@ def test_non_input_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error _ = -b.grad + + +@mark.parametrize("chunk_size", [1, 3, None]) +def test_tensor_used_multiple_times(chunk_size: int | None): + """ + Tests that backward works correctly when one of the inputs is used multiple times. In this + setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic. + """ + + a = torch.tensor(3.0, requires_grad=True, device=DEVICE) + b = 2.0 * a + c = a * b + d = a * c + e = a * d + aggregator = UPGrad() + + backward([d, e], aggregator=aggregator, parallel_chunk_size=chunk_size, retain_graph=True) + + expected_jacobian = torch.tensor( + [ + [2.0 * 3.0 * a**2], + [2.0 * 4.0 * a**3], + ], + device=DEVICE, + ) + + assert_close(a.grad, aggregator(expected_jacobian).squeeze()) + + +def test_rnn(): + """ + Tests that backward works for a very simple RNN, adapted from + [PyTorch's documentation](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html). + """ + + rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2).to(device=DEVICE) + input = torch.randn(5, 3, 10, device=DEVICE) # Batch of 3 sequences of length 5 and of dim 10. + h0 = torch.randn(2, 3, 20, device=DEVICE) # Batch of 3 hidden states of 2 layers of dim 20. + output, _ = rnn(input, h0) # Output is of shape [5, 3, 20]. + target = torch.randn(5, 3, 20, device=DEVICE) # Batch of 3 sequences of len 5 and of dim 20. + losses = ((output - target) ** 2).sum(dim=[1, 2]) # 1 loss per sequence element. + aggregator = UPGrad() + + # It's necessary to avoid using vmap by setting the parallel_chunk_size to 1 because the cuda + # implementation of RNN is not supported by vmap. + backward(tensors=losses, aggregator=aggregator, parallel_chunk_size=1, retain_graph=True) + + for param in rnn.parameters(): + assert param.grad is not None and param.grad.shape == param.shape diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index c52ea643a..2dd026e09 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -1,7 +1,10 @@ from contextlib import nullcontext as does_not_raise +from itertools import chain import torch from pytest import mark, raises +from torch import nn +from torch.nn import BCELoss, MSELoss from torch.testing import assert_close from unit._utils import ExceptionContext from unit.conftest import DEVICE @@ -29,15 +32,17 @@ def test_various_aggregators(aggregator: Aggregator): assert (p.grad is not None) and (p.shape == p.grad.shape) -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()]) -@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)]) +@mark.parametrize("aggregator", [Mean(), UPGrad()]) +@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_shared_params", [True, False]) @mark.parametrize("manually_specify_tasks_params", [True, False]) +@mark.parametrize("chunk_size", [1, 3, None]) def test_value_is_correct( aggregator: Aggregator, shape: tuple[int, int], manually_specify_shared_params: bool, manually_specify_tasks_params: bool, + chunk_size: int | None, ): """ Tests that the .grad value filled by mtl_backward is correct in a simple example of @@ -74,6 +79,8 @@ def test_value_is_correct( aggregator=aggregator, tasks_params=tasks_params, shared_params=shared_params, + retain_graph=True, + parallel_chunk_size=chunk_size, ) assert_close(p1.grad, f) @@ -592,3 +599,48 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): aggregator=UPGrad(), retain_graph=True, ) + + +def test_rnn(): + """ + Tests that mtl_backward works for simple multitask model whose feature extractor is an RNN + adapted from + [PyTorch's documentation](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html). + + Here, we have a binary classification task and a 4-regressions task using the last hidden state + of the RNN as shared input features. + """ + + rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2).to(device=DEVICE) + cls_head = nn.Linear(40, 1).to(device=DEVICE) + reg_head = nn.Linear(40, 4).to(device=DEVICE) + + input = torch.randn(5, 3, 10, device=DEVICE) # Batch of 3 sequences of length 5 and of dim 10. + h0 = torch.randn(2, 3, 20, device=DEVICE) # Batch of 3 hidden states of 2 layers of dim 20. + _, hn = rnn(input, h0) # hn is of shape [2, 3, 20]. + features = hn.permute(1, 0, 2).reshape(3, -1) + cls_output = torch.sigmoid(cls_head(features)).squeeze() + reg_output = reg_head(features) + + cls_loss_fn = BCELoss() + reg_loss_fn = MSELoss() + + cls_target = torch.tensor([1.0, 0.0, 1.0], device=DEVICE) + reg_target = torch.randn(3, 4, device=DEVICE) + + cls_loss = cls_loss_fn(cls_output, cls_target) + reg_loss = reg_loss_fn(reg_output, reg_target) + losses = [cls_loss, reg_loss] + + # It's necessary to avoid using vmap by setting the parallel_chunk_size to 1 because the cuda + # implementation of RNN is not supported by vmap. + mtl_backward( + losses=losses, + features=features, + aggregator=UPGrad(), + parallel_chunk_size=1, + retain_graph=True, + ) + + for param in chain(rnn.parameters(), cls_head.parameters(), reg_head.parameters()): + assert param.grad is not None and param.grad.shape == param.shape