[AutoDiff] Invalid derivative type calculation for same-type conformance #78358

Open · asl opened this issue Dec 24, 2024 · 2 comments

Labels: AutoDiff, bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.)

Comments

asl (Contributor) commented Dec 24, 2024

Description

Consider the reproduction below. If validateVJPWithError uses the more generic requirement S.TangentVector: FloatingPoint instead of the same-type conformance S.TangentVector == Double, the printed pullback value is correct (just comment out the S.TangentVector == Double line in the reproduction and uncomment the preceding one). It turns out that abstraction differences in derivatives are not taken into account: the caller expects the pullback to return its value indirectly, while the pullback actually returns it directly.

Indeed, the generated SIL contains:

  %28 = function_ref @$s4conf20validateVJPWithError2of2atyq_xYjrXE_xt16_Differentiation14DifferentiableRzSFR_AeFR_13TangentVectorAeFPQy_Rs_SdAGRtzr0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> () // user: %29
  %29 = apply %28<Double, Double>(%25, %26) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>, @in_guaranteed τ_0_0) -> ()

Note that validateVJPWithError accepts a differentiable function that returns its result indirectly (@out τ_0_1). We are passing the following function (%25) to it:

  %20 = differentiable_function_extract [vjp] %9  // user: %22
  // function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
  %21 = function_ref @$sS4dIegyd_Igydo_S2dxSdRi_zRi0_zlySdIsegnd_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %22
  %22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <Double>) // user: %23
  %23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // users: %34, %24
  %24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double> // user: %25
  %25 = differentiable_function [parameters 0] [results 0] %14 with_derivative {%19, %24} // user: %29

Note that:

  • The VJP type is $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <τ_0_2>) for <Double, Double, Double>, so the VJP's value result is returned indirectly (@out τ_0_1). However, the pullback's result is returned directly (-> Double).
  • validateVJPWithError<A, B>(of:at:) expects the pullback to return its value indirectly (@out τ_0_1):
  // function_ref valueWithPullback<A, B>(at:of:)
  %7 = function_ref @$s16_Differentiation17valueWithPullback2at2ofq_0B0_13TangentVectorQzAFQy_c8pullbacktx_q_xYjrXEtAA14DifferentiableRzAaJR_r0_lF : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %8
  %8 = apply %7<S, T>(%5, %1, %0) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector>) // user: %9
  %9 = convert_function %8 to $@callee_guaranteed (@in_guaranteed T) -> @out Double // user: %11
  // function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed B) -> (@out Double)
  %10 = function_ref @$sq_SdIegnr_q_SdIegnd_16_Differentiation14DifferentiableRzSFR_AaBR_13TangentVectorAaBPQy_Rs_SdACRtzr0_lTR : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %11
  %11 = partial_apply [callee_guaranteed] %10<S, T>(%9) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : FloatingPoint, τ_0_1 : Differentiable, τ_0_1 == τ_0_1.TangentVector, τ_0_0.TangentVector == Double> (@in_guaranteed τ_0_1, @guaranteed @callee_guaranteed (@in_guaranteed τ_0_1) -> @out Double) -> Double // user: %12
  %12 = convert_function %11 to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Double for <T> // user: %13

As a result, there is an abstraction difference and a junk value is returned. By contrast, removing the same-type conformance yields the following VJP type, which introduces the necessary reabstraction conversion (the pullback now returns its result indirectly, @out τ_0_1):

  %20 = differentiable_function_extract [vjp] %9  // user: %22
  // function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
  %21 = function_ref @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %22
  %22 = partial_apply [callee_guaranteed] %21(%20) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %23
  %23 = convert_function %22 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // users: %34, %24
  %24 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // user: %25

Reproduction

import _Differentiation

@differentiable(reverse)
@_silgen_name("_oneOverX")
func oneOverX(_ x: Double) -> Double {
    1 / x
}

@_silgen_name("_vjpOneOverX")
func _vjpOneOverX(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    (
        value: 1 / x,
        pullback: { v in
            -v / (x * x)
        }
    )
}

@_silgen_name("_pp")
@inline(never)
func pp<T>(_ v : T) {
  print(v)
}

@inline(never)
func validateVJPWithError<S, T>(
    of function: @differentiable(reverse) (S) -> T,
    at point: S
) where S: Differentiable, T: Differentiable, T: FloatingPoint
    , T == T.TangentVector
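    //, S.TangentVector: FloatingPoint  // generic requirement: correct pullback value is printed
    , S.TangentVector == Double         // same-type conformance: a junk value is printed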
    {
    let vwpb = valueWithPullback(at: point, of: function)
    let pullback = vwpb.pullback(.init(1))

    pp(pullback)
}

@_silgen_name("_testOneOverX")
func testOneOverX(_ x: Double) {
    validateVJPWithError(of: { x in oneOverX(x) }, at: x)
}

testOneOverX(10.0)
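For comparison, a minimal sketch of the variant that the description reports as working: it uses the generic requirement S.TangentVector: FloatingPoint instead of S.TangentVector == Double, and returns the pullback value instead of printing it so the result can be checked. The function name pullbackAtOne is hypothetical, and the sketch assumes a toolchain that ships the _Differentiation module.

```swift
import _Differentiation

@differentiable(reverse)
func oneOverX(_ x: Double) -> Double {
    1 / x
}

// Hypothetical helper mirroring validateVJPWithError from the reproduction,
// but constrained with S.TangentVector: FloatingPoint (the variant reported
// to produce the correct pullback value) and returning the result.
@inline(never)
func pullbackAtOne<S, T>(
    of function: @differentiable(reverse) (S) -> T,
    at point: S
) -> S.TangentVector
where S: Differentiable, T: Differentiable, T: FloatingPoint,
      T == T.TangentVector, S.TangentVector: FloatingPoint
{
    let (_, pullback) = valueWithPullback(at: point, of: function)
    // Seed the pullback with 1; T(1) is a T.TangentVector because T == T.TangentVector.
    return pullback(T(1))
}

let grad = pullbackAtOne(of: { x in oneOverX(x) }, at: 10.0)
print(grad)
```

If this variant prints a value close to the analytic derivative while the reproduction above prints junk, the same-type conformance is isolated as the trigger.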

Expected behavior

The correct pullback value is printed. It would also be great if the SIL verifier caught this abstraction difference.
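As an independent reference for the expected value, assuming the intended result is the pullback of f(x) = 1/x seeded with 1: that pullback equals the derivative -1/x², which is -0.01 at x = 10. The sketch below checks this with a central finite difference and needs no AutoDiff; the helper centralDifference and the step size h are illustrative choices, not part of the issue.

```swift
// Illustrative helper (not from the issue): central finite-difference
// approximation of f'(x), with truncation error O(h^2).
func centralDifference(_ f: (Double) -> Double, at x: Double, h: Double = 1e-6) -> Double {
    (f(x + h) - f(x - h)) / (2 * h)
}

func oneOverX(_ x: Double) -> Double {
    1 / x
}

let x = 10.0
// Closed form: d/dx (1/x) = -1/x^2, so the pullback of seed 1 is -1/x^2.
let analytic = -1 / (x * x)
let numeric = centralDifference(oneOverX, at: x)
print("analytic:", analytic, "numeric:", numeric)
```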

Environment

Swift version 6.2-dev (LLVM e404f8897f17aff, Swift b47b157)
Target: arm64-apple-macosx15.0


asl added the bug and AutoDiff labels (and removed the triage needed label) Dec 24, 2024
asl (Contributor, Author) commented Dec 24, 2024

Tagging @JaapWijnen @rxwei

JaapWijnen (Contributor) commented:

Thanks for looking into this one!
