
Adjoint for active values in loops are just wrong #78264

Open
asl opened this issue Dec 18, 2024 · 2 comments
Labels: AutoDiff, bug


asl commented Dec 18, 2024

Description

Kudos to @kovdan01 for initial analysis of this issue.

It turns out that adjoints for active values in loops are just plain wrong. Consider the reproducer below. As one can see, the gradient for repeat_while_loop is wrong, while the gradient for while_loop is correct. Moreover, if we replace the loop body with result *= x, the repeat_while_loop case starts working.

Why is that?

The loop body for repeat_while_loop looks as follows (the loop condition calculation is removed for brevity):

  br bb1                                          // id: %10

bb1:                                              // Preds: bb2 bb0
  %11 = metatype $@thin Float.Type                // user: %16
  %12 = begin_access [read] [static] %2           // users: %14, %13
  %13 = load [trivial] %12                        // user: %16
  end_access %12                                  // id: %14
  // function_ref static Float.* infix(_:_:)
  %15 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %16
  %16 = apply %15(%13, %0, %11) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %18
  %17 = begin_access [modify] [static] %2         // users: %18, %19
  store %16 to [trivial] %17                      // id: %18
  end_access %17                                  // id: %19
...
  cond_br %39, bb2, bb3                           // id: %40

bb2:                                              // Preds: bb1
  br bb1
bb3:                                              // Preds: bb1
...

The key thing here is the active value %13 (which is essentially the result value), so we need to generate an adjoint for it. The AutoDiff code uses the notion of "the adjoint for value X in basic block Y". This is fine for code without loops, but it is plain wrong for values inside loops, where it should be "the adjoint for value X in basic block Y on loop iteration Z". The values differ between loop iterations, so their adjoints should be distinct as well. Without this, we end up with artificial adjoint accumulation (since a single adjoint value is shared between loop iterations) and wrong results.

So, when generating the pullback for this loop body, we need to ensure that the adjoint of %13 starts at zero on each iteration, and then perform the usual pullback cloning, which involves adjoint value generation and accumulation. We don't do this, so we essentially accumulate into the adjoint left over from the previous loop iteration.
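To make the per-iteration requirement concrete, here is a minimal hand-written model of the pullback of the repeat_while_loop body in plain Swift (no _Differentiation needed). It assumes the usual reverse-mode rules for the three instructions that matter: the load (%13), the multiply (%16), and the store back into result. The function loopPullback and its resetPerIteration flag are hypothetical; the flag switches between a single adjoint shared across all iterations (the behavior described above) and one that starts from zero on every iteration. This is a sketch of the reasoning, not the compiler's generated pullback.

```swift
// Toy model of the pullback for the repeat_while_loop body:
//   a = load result     (%13)
//   m = a * x           (%16)
//   store m to result
// `result` has an adjoint buffer (adjResult); `a` and `m` have adjoint values.
func loopPullback(x: Float, iterations: Int, resetPerIteration: Bool) -> Float {
    // Forward pass: record the value loaded on each iteration.
    var result = x
    var loaded: [Float] = []
    for _ in 0..<iterations {
        loaded.append(result)
        result = result * x
    }

    // Reverse pass, seeded with d(result)/d(result) = 1.
    var adjResult: Float = 1
    var adjX: Float = 0
    // Adjoints of the in-loop active values. Without a reset, one pair of
    // adjoints is shared by all iterations (the bug described above).
    var adjLoaded: Float = 0
    var adjProduct: Float = 0

    for a in loaded.reversed() {
        if resetPerIteration {
            adjLoaded = 0
            adjProduct = 0
        }
        // store m to result: move the buffer's adjoint into m, clear the buffer.
        adjProduct += adjResult
        adjResult = 0
        // m = a * x: accumulate into both operands.
        adjLoaded += adjProduct * x
        adjX += adjProduct * a
        // a = load result: accumulate a's adjoint back into the buffer.
        adjResult += adjLoaded
    }

    // `var result = x`: the remaining adjoint flows back into x.
    adjX += adjResult
    return adjX
}

print(loopPullback(x: 2, iterations: 2, resetPerIteration: true))   // 12.0
print(loopPullback(x: 2, iterations: 2, resetPerIteration: false))  // 18.0
```

In this toy model the shared adjoint yields 18 instead of the expected 3x² = 12 at x = 2; it is not a claim about the exact wrong value the compiler prints. With a single iteration both variants agree, since there is no earlier iteration to leak from.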

Sure, but if things are so broken, why haven't we noticed this before? I would say: coincidence.

For code like result *= x we do not have these extra active values: Float.*= takes the adjoint buffer as an inout argument and performs proper adjoint generation there.
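For comparison, a similar toy model of the in-place form. The helper timesEqualPullback is hypothetical; it only mirrors the shape described above, where the pullback of *= updates the adjoint buffer of the left-hand side in place (inout) and returns the adjoint for the right-hand side. Since the buffer for result is the only adjoint state, there are no intermediate adjoint values that could carry over between iterations.

```swift
// Hypothetical helper: pullback of `lhs *= rhs`, given the value of `lhs`
// before the assignment. Updates the adjoint buffer of `lhs` in place and
// returns the adjoint contribution for `rhs`.
func timesEqualPullback(_ adjLhs: inout Float, lhs: Float, rhs: Float) -> Float {
    let adjRhs = adjLhs * lhs  // d(lhs * rhs)/d(rhs) = lhs
    adjLhs *= rhs              // d(lhs * rhs)/d(lhs) = rhs
    return adjRhs
}

// Toy model of the loop with `result *= x` as its body.
func inPlaceLoopPullback(x: Float, iterations: Int) -> Float {
    // Forward pass: record `result` before each `*=`.
    var result = x
    var before: [Float] = []
    for _ in 0..<iterations {
        before.append(result)
        result *= x
    }

    // Reverse pass: only the buffer for `result` is carried between
    // iterations, so nothing stale can leak into the next one.
    var adjResult: Float = 1
    var adjX: Float = 0
    for r in before.reversed() {
        adjX += timesEqualPullback(&adjResult, lhs: r, rhs: x)
    }
    adjX += adjResult  // `var result = x` sends the remaining adjoint to x
    return adjX
}

print(inPlaceLoopPullback(x: 2, iterations: 2))  // 12.0
```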

The while / for case is more interesting. Here the code looks as follows:

  br bb1                                          // id: %10

bb1:                                              // Preds: bb2 bb0
...
  cond_br %21, bb2, bb3                           // id: %22

bb2:                                              // Preds: bb1
  %23 = metatype $@thin Float.Type                // user: %28
  %24 = begin_access [read] [static] %2           // users: %26, %25
  %25 = load [trivial] %24                        // user: %28
  end_access %24                                  // id: %26
  // function_ref static Float.* infix(_:_:)
  %27 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %28
  %28 = apply %27(%25, %0, %23) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %30
  %29 = begin_access [modify] [static] %2         // users: %30, %31
  store %28 to [trivial] %29                      // id: %30
  end_access %29                                  // id: %31
...
  br bb1                                          // id: %41

bb3:                                              // Preds: bb1
  %42 = begin_access [read] [static] %2           // users: %44, %43
  %43 = load [trivial] %42                        // user: %47
  end_access %42                                  // id: %44

So, we have the loop header (bb1) first, then the loop body, and finally the code after the loop (bb3). Now, we have code that propagates adjoints of active values into predecessor BBs while traversing the function in reverse post-order. Here, we first visit bb3, then bb1. Inside bb1 we realize that there are active values (%25 and %28) in the predecessor bb2, so we take their adjoints in bb1 and propagate them into bb2. Since no adjoints were defined before, they are zero-initialized and propagated further. And since bb1 coincidentally is the loop header, we end up zero-initializing them on every loop iteration in the pullback, because the pullback of the loop header is executed after that of the loop body. Everything magically works.

The situation with the repeat loop is reversed: there is no "loop header" BB in the usual sense; instead, there is a "loop footer" fused into the loop body. So the adjoints for %13 and %16 are first zero-initialized in the pullback block corresponding to bb3 and then propagated further to bb1. Hence there is no zero-initialization on each loop iteration, adjoint values are reused from the previous loop iteration, and wrong results are produced.
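The ordering argument can be illustrated with a small model that only tracks which pullback blocks run, and in what order, for two iterations. It assumes pullback blocks execute in the reverse of the original control-flow path and reuses the block names from the SIL excerpts above; the LoopShape type and its zeroingBlock / bodyBlock fields are hypothetical simplifications of where the zero adjoints get materialized.

```swift
// Toy model: order of pullback blocks for two iterations, and how many
// body pullbacks start from freshly zeroed adjoints.
struct LoopShape {
    let forwardPath: [String]  // blocks executed by the original function
    let zeroingBlock: String   // pullback block that zero-initializes body adjoints
    let bodyBlock: String      // pullback block that consumes the body adjoints

    // Pullback blocks run in the reverse of the original path.
    var pullbackOrder: [String] { Array(forwardPath.reversed()) }

    var freshBodyPullbacks: Int {
        var fresh = false
        var count = 0
        for block in pullbackOrder {
            if block == zeroingBlock { fresh = true }
            if block == bodyBlock {
                if fresh { count += 1 }
                fresh = false  // the body pullback leaves non-zero adjoints behind
            }
        }
        return count
    }
}

let whileShape = LoopShape(
    forwardPath: ["bb0", "bb1", "bb2", "bb1", "bb2", "bb1", "bb3"],
    zeroingBlock: "bb1", bodyBlock: "bb2")
let repeatWhileShape = LoopShape(
    forwardPath: ["bb0", "bb1", "bb2", "bb1", "bb3"],
    zeroingBlock: "bb3", bodyBlock: "bb1")

print(whileShape.pullbackOrder, whileShape.freshBodyPullbacks)
// ["bb3", "bb1", "bb2", "bb1", "bb2", "bb1", "bb0"] 2
print(repeatWhileShape.pullbackOrder, repeatWhileShape.freshBodyPullbacks)
// ["bb3", "bb1", "bb2", "bb1", "bb0"] 1
```

For while_loop every body pullback (bb2) is preceded by a header pullback (bb1), so both iterations start from zero; for repeat_while_loop only the first body pullback (bb1) follows the zeroing in bb3.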

It seems to me that we need to perform explicit adjoint zeroing inside loop headers in the pullback cloner.

Reproduction

import _Differentiation

func repeat_while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    repeat {
      result = result * x
      i += 1
    } while i < 2
    return result
}

func while_loop(_ x: Float) -> Float {
    var result = x
    var i = 0
    while i < 2 {
      result = result * x
      i += 1
    }
    return result
}

print(valueWithGradient(at: 2, of: repeat_while_loop))
print(valueWithGradient(at: 2, of: while_loop))

Expected behavior

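Correct gradient calculation for both cases. Both functions compute x³, so at x = 2 each should report a value of 8.0 and a gradient of 12.0 (3x²).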
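The expected numbers can be cross-checked without _Differentiation using a central finite difference. The functions below are Double copies of the reproducer, renamed to camelCase (hypothetical names) so they do not clash with the originals; the step size h = 1e-5 is an arbitrary choice, so the derivative is only approximately 12.

```swift
func repeatWhileLoop(_ x: Double) -> Double {
    var result = x
    var i = 0
    repeat {
        result = result * x
        i += 1
    } while i < 2
    return result
}

func whileLoop(_ x: Double) -> Double {
    var result = x
    var i = 0
    while i < 2 {
        result = result * x
        i += 1
    }
    return result
}

/// Central finite difference: (f(x + h) - f(x - h)) / 2h.
func centralDifference(_ f: (Double) -> Double, at x: Double, h: Double = 1e-5) -> Double {
    (f(x + h) - f(x - h)) / (2 * h)
}

for f in [repeatWhileLoop, whileLoop] {
    // Prints the value 8.0 and a derivative estimate close to 12.
    print(f(2), centralDifference(f, at: 2))
}
```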

Environment

Swift version 6.2-dev (LLVM e404f8897f17aff, Swift 5a68861)
Target: arm64-apple-macosx15.0

Additional information

No response

asl added the AutoDiff and bug labels on Dec 18, 2024

asl commented Dec 18, 2024

Tagging @kovdan01 @JaapWijnen @rxwei @dan-zheng


asl commented Dec 18, 2024

It looks like one way of fixing this is as follows:

  • For each basic block, identify the corresponding innermost loop header.
  • Then, for each active value inside this loop (but not in nested loops), explicitly zero the corresponding adjoint in the loop header.

Given that in the pullback the loop header is executed after the corresponding loop body, we will effectively zero out the adjoints of active values after each iteration. Values "reused" from previous iterations are propagated via phi nodes and therefore will be unaffected; only those defined in the loop body will be affected.
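A rough sketch of the two steps on a toy CFG, assuming a reducible CFG (so DFS retreating edges are exactly the loop back edges) and plain string names for blocks and values. None of this corresponds to the Swift compiler's actual API: CFG, naturalLoops, innermostLoopHeader and adjointsToZero are hypothetical names, and the caller is assumed to exclude header phi arguments from the active values, per the note above.

```swift
// Hypothetical sketch of the proposed fix on a toy CFG; not the compiler's API.
struct CFG {
    let successors: [String: [String]]
    let entry: String

    var predecessors: [String: [String]] {
        var preds: [String: [String]] = [:]
        for (block, succs) in successors {
            for s in succs { preds[s, default: []].append(block) }
        }
        return preds
    }

    /// Back edges (tail -> header) found by a depth-first search from `entry`.
    func backEdges() -> [(tail: String, header: String)] {
        var edges: [(tail: String, header: String)] = []
        var visited: Set<String> = []
        var onStack: Set<String> = []
        func dfs(_ block: String) {
            visited.insert(block)
            onStack.insert(block)
            for s in successors[block, default: []] {
                if onStack.contains(s) {
                    edges.append((tail: block, header: s))
                } else if !visited.contains(s) {
                    dfs(s)
                }
            }
            onStack.remove(block)
        }
        dfs(entry)
        return edges
    }

    /// Natural loop per header: the header plus every block that reaches a
    /// back-edge tail without passing through the header.
    func naturalLoops() -> [String: Set<String>] {
        let preds = predecessors
        var loops: [String: Set<String>] = [:]
        for (tail, header) in backEdges() {
            var body = loops[header, default: [header]]
            var worklist = [tail]
            while let block = worklist.popLast() {
                if body.insert(block).inserted {
                    worklist.append(contentsOf: preds[block, default: []])
                }
            }
            loops[header] = body
        }
        return loops
    }

    /// Step 1: header of the innermost (smallest) loop containing each block.
    func innermostLoopHeader() -> [String: String] {
        let loops = naturalLoops()
        var result: [String: String] = [:]
        for (header, body) in loops {
            for block in body {
                if let current = result[block], loops[current]!.count <= body.count {
                    continue
                }
                result[block] = header
            }
        }
        return result
    }
}

/// Step 2: group each block's active values under its innermost loop header;
/// the pullback block of that header would zero these adjoints.
func adjointsToZero(cfg: CFG, activeValues: [String: [String]]) -> [String: [String]] {
    let headers = cfg.innermostLoopHeader()
    var plan: [String: [String]] = [:]
    for (block, values) in activeValues {
        guard let header = headers[block] else { continue }  // not inside a loop
        plan[header, default: []].append(contentsOf: values)
    }
    return plan.mapValues { $0.sorted() }
}

// CFG and in-loop active values of repeat_while_loop from the SIL excerpt.
let repeatWhileCFG = CFG(
    successors: ["bb0": ["bb1"], "bb1": ["bb2", "bb3"], "bb2": ["bb1"], "bb3": []],
    entry: "bb0")
print(adjointsToZero(cfg: repeatWhileCFG, activeValues: ["bb1": ["%13", "%16"]]))
// ["bb1": ["%13", "%16"]]
```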

kovdan01 added two commits to kovdan01/swift that referenced this issue on Dec 18, 2024