Skip to content

Commit

Permalink
Void augmented (rust-lang#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 5, 2021
1 parent f426b09 commit e97c0d1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
7 changes: 2 additions & 5 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,10 +980,6 @@ class Enzyme : public ModulePass {
tape = Builder.CreateLoad(
Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType)));
}
llvm::errs() << *CI->getParent() << "\n";
llvm::errs() << *CI->getParent() << "\n";
llvm::errs() << *tape << "\n";
llvm::errs() << *tapeType << "\n";
assert(tape->getType() == tapeType);
args.push_back(tape);
}
Expand Down Expand Up @@ -1056,7 +1052,8 @@ class Enzyme : public ModulePass {
}
}

if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy()) {
if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy() &&
!CI->getType()->isEmptyTy() && !CI->getType()->isVoidTy()) {
if (diffret->getType() == CI->getType()) {
CI->replaceAllUsesWith(diffret);
} else if (mode == DerivativeMode::ReverseModePrimal) {
Expand Down
38 changes: 38 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/splitSize4.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

; Function Attrs: noinline nounwind readnone uwtable
define double @tester(double* %x) {
entry:
%gep = getelementptr double, double* %x, i32 1
%y = load double, double* %x
%z = load double, double* %gep
%res = fmul fast double %y, %z
ret double %res
}

define void @test_derivative(double* %x, double* %dx) {
entry:
%size = call i64 (double (double*)*, ...) @__enzyme_augmentsize(double (double*)* nonnull @tester, metadata !"enzyme_dup")
%cache = alloca i8, i64 %size, align 1
call void (double (double*)*, ...) @__enzyme_augmentfwd(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
tail call void (double (double*)*, ...) @__enzyme_reverse(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
ret void
}

; Function Attrs: nounwind
declare void @__enzyme_augmentfwd(double (double*)*, ...)
declare i64 @__enzyme_augmentsize(double (double*)*, ...)
declare void @__enzyme_reverse(double (double*)*, ...)

; CHECK: define void @test_derivative(double* %x, double* %dx)
; CHECK-NEXT: entry:
; CHECK-NEXT: %cache = alloca i8, i64 16
; CHECK-NEXT: %0 = call { { double, double }, double } @augmented_tester(double* %x, double* %dx)
; CHECK-NEXT: %1 = extractvalue { { double, double }, double } %0, 0
; CHECK-NEXT: %2 = bitcast i8* %cache to { double, double }*
; CHECK-NEXT: store { double, double } %1, { double, double }* %2
; CHECK-NEXT: %3 = bitcast i8* %cache to { double, double }*
; CHECK-NEXT: %4 = load { double, double }, { double, double }* %3
; CHECK-NEXT: call void @diffetester(double* %x, double* %dx, double 1.000000e+00, { double, double } %4)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit e97c0d1

Please sign in to comment.