Description
Consider the reproduction below. If `validateVJPWithError` uses the more generic requirement `S.TangentVector: FloatingPoint` instead of the same-type requirement `S.TangentVector == Double`, the printed pullback value is correct (to check, comment out the `S.TangentVector == Double` line in the reproduction and uncomment the preceding one). Abstraction differences in derivatives are not taken into account: the function expects the pullback to return its value indirectly, while the pullback returns it directly.
Note that `validateVJPWithError` accepts a differentiable function that returns its result indirectly (`@out τ_0_1`). We are passing the following function to it (`%25`):
Its VJP type is `$@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double>`, so the VJP returns its original result indirectly. However, the pullback returns its result directly (`-> Double`).
`validateVJPWithError<A, B>(of:at:)`, on the other hand, expects the pullback to return its value indirectly:
As a result, there is an abstraction mismatch and a junk value is returned. By contrast, removing the same-type requirement yields the following VJP type, which introduces the necessary reabstraction conversion:
Reproduction
import _Differentiation

@differentiable(reverse)
@_silgen_name("_oneOverX")
func oneOverX(_ x: Double) -> Double {
  1 / x
}

@_silgen_name("_vjpOneOverX")
func _vjpOneOverX(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  (value: 1 / x,
   pullback: { v in -v / (x * x) })
}

@_silgen_name("_pp")
@inline(never)
func pp<T>(_ v: T) { print(v) }

@inline(never)
func validateVJPWithError<S, T>(
  of function: @differentiable(reverse) (S) -> T,
  at point: S
) where S: Differentiable, T: Differentiable, T: FloatingPoint, T == T.TangentVector
//, S.TangentVector: FloatingPoint
, S.TangentVector == Double {
  let vwpb = valueWithPullback(at: point, of: function)
  let pullback = vwpb.pullback(.init(1))
  pp(pullback)
}

@_silgen_name("_testOneOverX")
func testOneOverX(_ x: Double) {
  validateVJPWithError(of: { x in oneOverX(x) },
                       at: x)
}

testOneOverX(10.0)
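For contrast with the missing conversion described above, here is a minimal sketch of ordinary (non-AutoDiff) reabstraction, which the compiler does handle. When a concrete `(Double) -> Double` function is passed where a generic `(T) -> U` is expected, SILGen wraps it in a reabstraction thunk that converts the direct `Double` argument and result to the indirect `@in_guaranteed`/`@out` convention used by the unspecialized generic callee. The names `applyGeneric` and `negReciprocalSquare` are hypothetical and not part of the reproduction. The thunk can be inspected with `swiftc -emit-silgen`.

```swift
// Hypothetical generic helper: it only knows `body` as `(T) -> U`, so in its
// unspecialized SIL the argument and result are passed indirectly
// (@in_guaranteed T, @out U).
func applyGeneric<T, U>(_ body: (T) -> U, to value: T) -> U {
  body(value)
}

// Hypothetical concrete function: Double in, Double out, both passed directly.
func negReciprocalSquare(_ x: Double) -> Double {
  -1 / (x * x)
}

// Passing the concrete function where `(T) -> U` is expected makes SILGen
// insert a reabstraction thunk bridging direct <-> indirect passing, so the
// generic callee reads a valid result instead of junk.
let result = applyGeneric(negReciprocalSquare, to: 10.0)
print(result)  // prints "-0.01"
```

The bug report is, in effect, that the analogous thunk is not emitted for the pullback when the `S.TangentVector == Double` requirement is present.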
Expected behavior
The correct pullback value is printed. It would also be great if the SIL verifier caught this abstraction difference.
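This sketch shows what "correct" means for the reproduction's inputs (x = 10 and a seed of 1). The hand-written pullback `{ v in -v / (x * x) }` should yield -1/x² = -0.01. The sketch computes that value without `_Differentiation` and cross-checks it against a central finite difference. `vjpOneOverXReference` and `centralDifference` are hypothetical helper names, not part of the reproduction.

```swift
// Same math as `_vjpOneOverX` in the reproduction, written as a plain function
// (no @_silgen_name, no _Differentiation import) so it runs with any toolchain.
func vjpOneOverXReference(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  (value: 1 / x, pullback: { v in -v / (x * x) })
}

// Central finite difference of f at x, used as an independent check.
func centralDifference(_ f: (Double) -> Double, at x: Double, step h: Double = 1e-6) -> Double {
  (f(x + h) - f(x - h)) / (2 * h)
}

let x = 10.0
let (value, pullback) = vjpOneOverXReference(x)
let expected = pullback(1)  // seed of 1, mirroring `.init(1)` in the reproduction
let numeric = centralDifference({ 1 / $0 }, at: x)

print(value, expected)                 // prints "0.1 -0.01"
print(abs(expected - numeric) < 1e-6)  // prints "true"
```

Any other value printed by `testOneOverX(10.0)` (for example, the junk value described above) indicates the abstraction mismatch.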
Environment
Swift version 6.2-dev (LLVM e404f8897f17aff, Swift b47b157)
Target: arm64-apple-macosx15.0
Additional information
No response
asl added the bug and AutoDiff labels and removed the triage needed label on Dec 24, 2024.