Skip to content

Commit

Permalink
Emit reabstraction thunks for implicit conversions between T.TangentT…
Browse files Browse the repository at this point in the history
…ype and Optional<T>.TangentType

Fixes #77924
  • Loading branch information
asl committed Dec 10, 2024
1 parent 23c577d commit 49b7a36
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
98 changes: 98 additions & 0 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "silgen-poly"
#include "ArgumentSource.h"
#include "ExecutorBreadcrumb.h"
#include "FunctionInputGenerator.h"
#include "Initialization.h"
Expand Down Expand Up @@ -675,6 +676,103 @@ ManagedValue Transform::transform(ManagedValue v,
return std::move(result).getAsSingleValue(SGF, Loc);
}

// - T.TangentVector to Optional<T>.TangentVector
// Optional<T>.TangentVector is a struct wrapping Optional<T.TangentVector>
// So we just need to call appropriate .init on it.
// However, we might have T.TangentVector == T, so we need to calculate all
// required types first.
if (CanType optionalTy = outputSubstType.getNominalParent(); // `Optional<T>`
optionalTy && (bool)optionalTy.getOptionalObjectType()) {
// `T`
CanType wrappedType = optionalTy.getOptionalObjectType();
// Check that T.TangentVector is indeed inputSubstType (this also handles
// case when T == T.TangentVector)
auto inputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (inputTanSpace && inputTanSpace->getCanonicalType() == inputSubstType) {
auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal();
// Look up the `Optional<T>.TangentVector.init` declaration.
auto initLookup =
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
ConstructorDecl *constructorDecl = nullptr;
for (auto *candidate : initLookup) {
auto candidateModule = candidate->getModuleContext();
if (candidateModule->getName() ==
SGF. getASTContext().Id_Differentiation ||
candidateModule->isStdlibModule()) {
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
constructorDecl = cast<ConstructorDecl>(candidate);
#ifdef NDEBUG
break;
#endif
}
}
assert(constructorDecl && "No `Optional.TangentVector.init`");

// `T.TangentVector`
CanType wrappedTanType = inputTanSpace->getCanonicalType();
// `Optional<T.TangentVector>`
CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType();

const TypeLowering &optTL = SGF.getTypeLowering(optionalOfWrappedTanType);
auto optVal = SGF.emitInjectOptional(Loc, optTL, ctxt,
[&](SGFContext objectCtxt) {
return v;
});
auto *diffProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto diffConf = lookupConformance(wrappedType, diffProto);
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
ConcreteDeclRef initDecl(constructorDecl,
SubstitutionMap::get(constructorDecl->getGenericSignature(),
{wrappedType}, {diffConf}));
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
args.add(Loc, RValue(SGF, {optVal}, optionalOfWrappedTanType));

auto result = SGF.emitApplyAllocatingInitializer(Loc, initDecl,
std::move(args), outputSubstType, ctxt);
if (result.isInContext())
return ManagedValue::forInContext();
return std::move(result).getAsSingleValue(SGF, Loc);
}
}

// - Optional<T>.TangentVector to T.TangentVector.
if (CanType optionalTy = inputSubstType.getNominalParent(); // `Optional<T>`
optionalTy && (bool)optionalTy.getOptionalObjectType()) {
CanType wrappedType = optionalTy.getOptionalObjectType(); // `T`
// Check that T.TangentVector is indeed outputSubstType (this also handles
// case when T == T.TangentVector)
auto outputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (outputTanSpace && outputTanSpace->getCanonicalType() == outputSubstType) {
// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> property. This is an implementation detail of
// OptionalDifferentiation.swift
// TODO: Maybe it would be better to have getters / setters here that we
// can call and hide this implementation detail?
StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct();
VarDecl *wrappedValueVar = nullptr;
if (optStructDecl) {
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
}

EnumDecl *optDecl = wrappedValueVar ?
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
nullptr;

if (!optStructDecl || optDecl != SGF.getASTContext().getOptionalDecl())
llvm_unreachable("Unexpected type of Optional.TangentVector");

FormalEvaluationScope scope(SGF);
auto wrappedVal = SGF.B.createStructExtract(Loc, v, wrappedValueVar);
return SGF.emitCheckedGetOptionalValueFrom(Loc, wrappedVal,
/*isImplicitUnwrap*/ true,
SGF.getTypeLowering(wrappedVal.getType()),
ctxt);
}
}

// Should have handled the conversion in one of the cases above.
v.dump();
llvm_unreachable("Unhandled transform?");
Expand Down
15 changes: 13 additions & 2 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "clang/Sema/TemplateDeduction.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/SaveAndRestore.h"
Expand Down Expand Up @@ -7499,8 +7500,18 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
fromEI.intoBuilder()
.withDifferentiabilityKind(toEI.getDifferentiabilityKind())
.build();
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult(),
newEI);
SmallVector<AnyFunctionType::Param, 4> params(fromFunc->getParams());
assert(params.size() == toFunc->getParams().size() && "unexpected @differentiable conversion");
// Propagate @noDerivate from target function type
for (auto paramAndIndex : llvm::enumerate(toFunc->getParams())) {
if (!paramAndIndex.value().isNoDerivative())
continue;

auto &param = params[paramAndIndex.index()];
param = param.withFlags(param.getParameterFlags().withNoDerivative(true));
}

fromFunc = FunctionType::get(params, fromFunc->getResult(), newEI);
switch (toEI.getDifferentiabilityKind()) {
// TODO: Ban `Normal` and `Forward` cases.
case DifferentiabilityKind::Normal:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// https://github.com/swiftlang/swift/issues/77871
// Ensure we are correctl generating reabstraction thunks for Double <-> Optional<Double>
// conversion for derivatives: for differential and pullback we need
// to emit thunks to convert T.TangentVector <-> Optional<T>.TangentVector.

import _Differentiation

@differentiable(reverse)
func testFunc(_ x: Double?) -> Double? {
x! * x! * x!
}
print(pullback(at: 1.0, of: testFunc)(.init(1.0)) == 3.0)

0 comments on commit 49b7a36

Please sign in to comment.