Skip to content

Commit

Permalink
[AutoDiff] Fix adjoints for loop-local active values
Browse files Browse the repository at this point in the history
  • Loading branch information
kovdan01 committed Dec 18, 2024
1 parent f2ad9f3 commit f2e68d3
Show file tree
Hide file tree
Showing 3 changed files with 355 additions and 12 deletions.
50 changes: 50 additions & 0 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,56 @@ bool PullbackCloner::Implementation::run() {
// Visit original blocks in post-order and perform differentiation
// in corresponding pullback blocks. If errors occurred, back out.
else {
LLVM_DEBUG(getADDebugStream()
<< "Begin search for adjoints of loop-local active values\n");
llvm::DenseMap<SILLoop *, llvm::DenseSet<SILValue>> loopLocalActiveValues;
for (auto *bb : originalBlocks) {
SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
if (loop == nullptr)
continue;
SILBasicBlock *loopHeader = loop->getHeader();
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
LLVM_DEBUG(getADDebugStream()
<< "Original bb" << bb->getDebugID()
<< " belongs to a loop, original header bb"
<< loopHeader->getDebugID() << ", pullback header bb"
<< pbLoopHeader->getDebugID() << '\n');
builder.setInsertionPoint(pbLoopHeader);
auto &bbActiveValues = activeValues[bb];
for (SILValue bbActiveValue : bbActiveValues) {
if (vjpCloner.getLoopInfo()->getLoopFor(
bbActiveValue->getParentBlock()) != loop) {
LLVM_DEBUG(
getADDebugStream()
<< "The following active value is NOT loop-local, skipping: "
<< bbActiveValue);
continue;
}
auto [_, wasInserted] =
loopLocalActiveValues[loop].insert(bbActiveValue);
LLVM_DEBUG(getADDebugStream()
<< "The following active value is loop-local, "
<< (wasInserted ? "zeroing its adjoint in loop header: "
: "but its adjoint was already zeroed in "
"loop header, skipping: ")
<< bbActiveValue);
if (!wasInserted)
continue;
if (getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Object) {
setAdjointValue(bb, bbActiveValue,
makeZeroAdjointValue(getRemappedTangentType(
bbActiveValue->getType())));
} else {
assert(getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Address);
getAdjointBuffer(bb, bbActiveValue);
}
}
}
LLVM_DEBUG(getADDebugStream()
<< "End search for adjoints of loop-local active values\n");

for (auto *bb : originalBlocks) {
visitSILBasicBlock(bb);
if (errorOccurred)
Expand Down
Loading

0 comments on commit f2e68d3

Please sign in to comment.