fix(zygote ext): handle NamedTuple cotangents in vofa_u_adjoint (EnsembleSolution end-time grads)#623

Draft
ChrisRackauckas-Claude wants to merge 1 commit into
SciML:masterfrom
ChrisRackauckas-Claude:fix-namedtuple-cotangent-ensemblesolution

Please ignore until reviewed by @ChrisRackauckas.

Problem

Reverse-mode (Zygote) differentiation through an EnsembleSolution crashes when the loss reads only a scalar field of each trajectory's solution — e.g. the halting time produced by a terminate! callback:

```julia
T_stars = [sol.t[end] for sol in ensemble_sol.u]
mean(T_stars)
```

with

```
ERROR: MethodError: no method matching ndims(::@NamedTuple{…})
  [1] RecursiveArrayTools.VectorOfArray(vec::Vector{@NamedTuple{…}})
  [2] vofa_u_adjoint(d::Vector{@NamedTuple{…}}, A::EnsembleSolution{…})
```

ForwardDiff works; Zygote does not.

Cause

EnsembleSolution <: AbstractVectorOfArray, so accessing .u hits the literal_getproperty(::AbstractVectorOfArray, ::Val{:u}) adjoint → vofa_u_adjoint, which rewraps the per-element cotangents in a VectorOfArray. When only a scalar field of each per-trajectory ODESolution is differentiated, Zygote's cotangent for each element is a structural NamedTuple ((u = nothing, t = […])), not an array. VectorOfArray(::Vector{NamedTuple}) then calls ndims on a NamedTuple, which has no method.

This is the structural-tangent analogue of the ZeroTangent/size(::ZeroTangent) case fixed in #606 (same function, sibling failure mode) — which is likely why the report looked handled but wasn't.
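
The missing `ndims` method described above can be checked without any AD package. The snippet below is only an illustrative sketch: the value `ct` is a hand-written NamedTuple with the same shape as the cotangent quoted above, not one produced by Zygote, and the variable names are made up for the example.

```julia
# Hand-written stand-in for a structural cotangent: no gradient for `u`,
# a gradient only for the final entry of `t`.
ct = (u = nothing, t = [0.0, 0.0, 1.0])

# `ndims` has a method for arrays...
array_has_ndims = applicable(ndims, [1.0, 2.0])

# ...but none for a NamedTuple, which is what the reconstruction trips over.
namedtuple_has_ndims = applicable(ndims, ct)

println((array_has_ndims, namedtuple_has_ndims))
```

In a session with only Base loaded, the first check should report that a method exists and the second that it does not, which matches the `MethodError` in the stack trace.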

Fix

In both vofa_u_adjoint methods, detect when the mapped cotangents are not array-like and pass them through as a plain vector instead of forcing a VectorOfArray/DiffEqArray reconstruction.
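
The pass-through idea can be sketched roughly as follows. This is not the code in the PR: `WrappedCotangents`, `is_arraylike`, and `rewrap_or_passthrough` are hypothetical names invented for illustration, and the stand-in wrapper type is deliberately not the RecursiveArrayTools `VectorOfArray`. The "all elements are arrays" test is also an assumption; the actual detection logic may differ.

```julia
# Hypothetical stand-in for an array-of-arrays wrapper (not VectorOfArray).
struct WrappedCotangents{T}
    u::Vector{T}
end

# Assumed notion of "array-like" for this sketch.
is_arraylike(x) = x isa AbstractArray

# Rewrap only when every per-element cotangent is an array; otherwise
# hand the cotangents back unchanged as a plain vector.
function rewrap_or_passthrough(cts::Vector)
    if all(is_arraylike, cts)
        return WrappedCotangents(cts)
    else
        return cts
    end
end

# Array cotangents are wrapped as before.
arr = rewrap_or_passthrough([[1.0, 2.0], [3.0, 4.0]])

# Structural NamedTuple cotangents pass through untouched.
nt = rewrap_or_passthrough([(u = nothing, t = [0.0, 1.0]),
                            (u = nothing, t = [0.0, 2.0])])

# A mixed vector (NamedTuple and `nothing`) also passes through.
mixed = rewrap_or_passthrough(Any[(u = nothing, t = [1.0]), nothing])
```

The design point is that the wrapper is only constructed when its array assumptions hold, so structural tangents never reach a call such as `ndims`.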

Test

Added to test/AD/adjoints.jl, mirroring the #606 test: feeds NamedTuple cotangents through vofa_u_adjoint for both VectorOfArray and DiffEqArray, and a mixed NamedTuple/nothing/ZeroTangent case. Verified locally that the new test reproduces the exact ndims(::NamedTuple) error on the unmodified source and passes with the fix; the full AD/adjoints.jl suite passes.

Scope / not addressed

This fixes the crash so the cotangent flows through. Whether GaussAdjoint (or other solve adjoints) then propagates the event-time (sol.t[end]) sensitivity correctly through the ContinuousCallback/terminate! is a separate, downstream concern (the substance of SciML/DifferentialEquations.jl#1149) and is not claimed to be solved here.

Fixes SciML/DifferentialEquations.jl#1149

🤖 Generated with Claude Code

`vofa_u_adjoint` rewrapped the per-element cotangents of an
`AbstractVectorOfArray`/`AbstractDiffEqArray` in a `VectorOfArray`/`DiffEqArray`.
For an `EnsembleSolution` whose elements are solution objects, differentiating
only a scalar field (e.g. `sol.t[end]`) makes Zygote produce structural
`NamedTuple` cotangents per trajectory. These have no `ndims`, so the
reconstruction threw `MethodError: no method matching ndims(::NamedTuple)`.

Detect non-array cotangents and pass them through as a plain vector instead.
This extends the earlier `ZeroTangent`/`AbstractZero` handling (SciML#606) to the
structural-tangent case.

Fixes SciML/DifferentialEquations.jl#1149

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>

Successfully merging this pull request may close these issues.

Zygote Cant differentiate through EnsembleSolution end times when using terminate! callbacks