Skip to content

Commit

Permalink
[AutoDiff] Fix adjoints for loop-local active values
Browse files Browse the repository at this point in the history
  • Loading branch information
kovdan01 committed Dec 18, 2024
1 parent f2ad9f3 commit b17b3c3
Show file tree
Hide file tree
Showing 4 changed files with 403 additions and 22 deletions.
50 changes: 50 additions & 0 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,56 @@ bool PullbackCloner::Implementation::run() {
// Visit original blocks in post-order and perform differentiation
// in corresponding pullback blocks. If errors occurred, back out.
else {
LLVM_DEBUG(getADDebugStream()
<< "Begin search for adjoints of loop-local active values\n");
llvm::DenseMap<SILLoop *, llvm::DenseSet<SILValue>> loopLocalActiveValues;
for (auto *bb : originalBlocks) {
SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
if (loop == nullptr)
continue;
SILBasicBlock *loopHeader = loop->getHeader();
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
LLVM_DEBUG(getADDebugStream()
<< "Original bb" << bb->getDebugID()
<< " belongs to a loop, original header bb"
<< loopHeader->getDebugID() << ", pullback header bb"
<< pbLoopHeader->getDebugID() << '\n');
builder.setInsertionPoint(pbLoopHeader);
auto &bbActiveValues = activeValues[bb];
for (SILValue bbActiveValue : bbActiveValues) {
if (vjpCloner.getLoopInfo()->getLoopFor(
bbActiveValue->getParentBlock()) != loop) {
LLVM_DEBUG(
getADDebugStream()
<< "The following active value is NOT loop-local, skipping: "
<< bbActiveValue);
continue;
}
auto [_, wasInserted] =
loopLocalActiveValues[loop].insert(bbActiveValue);
LLVM_DEBUG(getADDebugStream()
<< "The following active value is loop-local, "
<< (wasInserted ? "zeroing its adjoint in loop header: "
: "but its adjoint was already zeroed in "
"loop header, skipping: ")
<< bbActiveValue);
if (!wasInserted)
continue;
if (getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Object) {
setAdjointValue(bb, bbActiveValue,
makeZeroAdjointValue(getRemappedTangentType(
bbActiveValue->getType())));
} else {
assert(getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Address);
getAdjointBuffer(bb, bbActiveValue);
}
}
}
LLVM_DEBUG(getADDebugStream()
<< "End search for adjoints of loop-local active values\n");

for (auto *bb : originalBlocks) {
visitSILBasicBlock(bb);
if (errorOccurred)
Expand Down
20 changes: 10 additions & 10 deletions test/AutoDiff/SILOptimizer/pullback_generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,18 @@ func f4(a: NonTrivial) -> Float {

// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %88 = alloc_stack $NonTrivial
// CHECK: %74 = alloc_stack $NonTrivial

// Non-trivial value must be copied or there will be a
// double consume when all owned parameters are destroyed
// at the end of the basic block.
// CHECK: %89 = copy_value %67 : $NonTrivial

// CHECK: store %89 to [init] %88 : $*NonTrivial
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
// CHECK: %92 = alloc_stack $Float
// CHECK: store %86 to [trivial] %92 : $*Float
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %95 = metatype $@thick Float.Type
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %75 = copy_value %67 : $NonTrivial

// CHECK: store %75 to [init] %74 : $*NonTrivial
// CHECK: %77 = struct_element_addr %74 : $*NonTrivial, #NonTrivial.x
// CHECK: %78 = alloc_stack $Float
// CHECK: store %72 to [trivial] %78 : $*Float
// CHECK: %80 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %81 = metatype $@thick Float.Type
// CHECK: %82 = apply %80<Float>(%77, %78, %81) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %67 : $NonTrivial
Loading

0 comments on commit b17b3c3

Please sign in to comment.