Skip to content

Commit

Permalink
Factor out emission into separate helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
asl committed Dec 10, 2024
1 parent 5a68861 commit b47b157
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 72 deletions.
17 changes: 17 additions & 0 deletions lib/SILGen/SILGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,23 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
CanSILFunctionType toType,
bool reorderSelf);

/// Emit conversion from T.TangentVector to Optional<T>.TangentVector.
ManagedValue
emitTangentVectorToOptionalTangentVector(SILLocation loc,
ManagedValue input,
CanType inputType,
CanType outputType,
SGFContext ctxt);

/// Emit conversion from Optional<T>.TangentVector to T.TangentVector.
ManagedValue
emitOptionalTangentVectorToTangentVector(SILLocation loc,
ManagedValue input,
CanType inputType,
CanType outputType,
SGFContext ctxt);


//===--------------------------------------------------------------------===//
// Back Deployment thunks
//===--------------------------------------------------------------------===//
Expand Down
166 changes: 94 additions & 72 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,90 @@ SILGenFunction::emitTransformExistential(SILLocation loc,
});
}

ManagedValue
SILGenFunction::emitTangentVectorToOptionalTangentVector(SILLocation loc,
ManagedValue input,
CanType inputType,
CanType outputType,
SGFContext ctxt) {
auto *optionalTanDecl = outputType.getNominalOrBoundGenericNominal();
// Look up the `Optional<T>.TangentVector.init` declaration.
auto initLookup =
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
ConstructorDecl *constructorDecl = nullptr;
for (auto *candidate : initLookup) {
auto candidateModule = candidate->getModuleContext();
if (candidateModule->getName() ==
getASTContext().Id_Differentiation ||
candidateModule->isStdlibModule()) {
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
constructorDecl = cast<ConstructorDecl>(candidate);
#ifdef NDEBUG
break;
#endif
}
}
assert(constructorDecl && "No `Optional.TangentVector.init`");

// `Optional<T.TangentVector>`
CanType optionalOfWrappedTanType = inputType.wrapInOptionalType();

const TypeLowering &optTL = getTypeLowering(optionalOfWrappedTanType);
auto optVal = emitInjectOptional(loc, optTL, ctxt,
[&](SGFContext objectCtxt) {
return input;
});
auto *diffProto = getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto diffConf = lookupConformance(inputType, diffProto);
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
ConcreteDeclRef initDecl(constructorDecl,
SubstitutionMap::get(constructorDecl->getGenericSignature(),
{inputType}, {diffConf}));
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
args.add(loc, RValue(*this, {optVal}, optionalOfWrappedTanType));

auto result = emitApplyAllocatingInitializer(loc, initDecl,
std::move(args), outputType, ctxt);
if (result.isInContext())
return ManagedValue::forInContext();
return std::move(result).getAsSingleValue(*this, loc);
}

ManagedValue
SILGenFunction::emitOptionalTangentVectorToTangentVector(SILLocation loc,
ManagedValue input,
CanType inputType,
CanType outputType,
SGFContext ctxt) {
// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> property. This is an implementation detail of
// OptionalDifferentiation.swift
// TODO: Maybe it would be better to have getters / setters here that we
// can call and hide this implementation detail?
StructDecl *optStructDecl = inputType.getStructOrBoundGenericStruct();
VarDecl *wrappedValueVar = nullptr;
if (optStructDecl) {
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
}

EnumDecl *optDecl = wrappedValueVar ?
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
nullptr;

if (!optStructDecl || optDecl != getASTContext().getOptionalDecl())
llvm_unreachable("Unexpected type of Optional.TangentVector");

FormalEvaluationScope scope(*this);
auto wrappedVal = B.createStructExtract(loc, input, wrappedValueVar);
return emitCheckedGetOptionalValueFrom(loc, wrappedVal,
/*isImplicitUnwrap*/ true,
getTypeLowering(wrappedVal.getType()),
ctxt);
}



/// Apply this transformation to an arbitrary value.
RValue Transform::transform(RValue &&input,
AbstractionPattern inputOrigType,
Expand Down Expand Up @@ -689,51 +773,11 @@ ManagedValue Transform::transform(ManagedValue v,
// case when T == T.TangentVector)
auto inputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (inputTanSpace && inputTanSpace->getCanonicalType() == inputSubstType) {
auto *optionalTanDecl = outputSubstType.getNominalOrBoundGenericNominal();
// Look up the `Optional<T>.TangentVector.init` declaration.
auto initLookup =
optionalTanDecl->lookupDirect(DeclBaseName::createConstructor());
ConstructorDecl *constructorDecl = nullptr;
for (auto *candidate : initLookup) {
auto candidateModule = candidate->getModuleContext();
if (candidateModule->getName() ==
SGF. getASTContext().Id_Differentiation ||
candidateModule->isStdlibModule()) {
assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s");
constructorDecl = cast<ConstructorDecl>(candidate);
#ifdef NDEBUG
break;
#endif
}
}
assert(constructorDecl && "No `Optional.TangentVector.init`");

// `T.TangentVector`
CanType wrappedTanType = inputTanSpace->getCanonicalType();
// `Optional<T.TangentVector>`
CanType optionalOfWrappedTanType = wrappedTanType.wrapInOptionalType();

const TypeLowering &optTL = SGF.getTypeLowering(optionalOfWrappedTanType);
auto optVal = SGF.emitInjectOptional(Loc, optTL, ctxt,
[&](SGFContext objectCtxt) {
return v;
});
auto *diffProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
auto diffConf = lookupConformance(wrappedType, diffProto);
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
ConcreteDeclRef initDecl(constructorDecl,
SubstitutionMap::get(constructorDecl->getGenericSignature(),
{wrappedType}, {diffConf}));
PreparedArguments args({AnyFunctionType::Param(optionalOfWrappedTanType)});
args.add(Loc, RValue(SGF, {optVal}, optionalOfWrappedTanType));

auto result = SGF.emitApplyAllocatingInitializer(Loc, initDecl,
std::move(args), outputSubstType, ctxt);
if (result.isInContext())
return ManagedValue::forInContext();
return std::move(result).getAsSingleValue(SGF, Loc);
}
if (inputTanSpace &&
inputTanSpace->getCanonicalType() == inputSubstType)
return SGF.emitTangentVectorToOptionalTangentVector(Loc, v,
inputSubstType, outputSubstType,
ctxt);
}

// - Optional<T>.TangentVector to T.TangentVector.
Expand All @@ -744,33 +788,11 @@ ManagedValue Transform::transform(ManagedValue v,
// case when T == T.TangentVector)
auto outputTanSpace =
wrappedType->getAutoDiffTangentSpace(LookUpConformanceInModule());
if (outputTanSpace && outputTanSpace->getCanonicalType() == outputSubstType) {
// Optional<T>.TangentVector should be a struct with a single
// Optional<T.TangentVector> property. This is an implementation detail of
// OptionalDifferentiation.swift
// TODO: Maybe it would be better to have getters / setters here that we
// can call and hide this implementation detail?
StructDecl *optStructDecl = inputSubstType.getStructOrBoundGenericStruct();
VarDecl *wrappedValueVar = nullptr;
if (optStructDecl) {
ArrayRef<VarDecl *> properties = optStructDecl->getStoredProperties();
wrappedValueVar = properties.size() == 1 ? properties[0] : nullptr;
}

EnumDecl *optDecl = wrappedValueVar ?
wrappedValueVar->getTypeInContext()->getEnumOrBoundGenericEnum() :
nullptr;

if (!optStructDecl || optDecl != SGF.getASTContext().getOptionalDecl())
llvm_unreachable("Unexpected type of Optional.TangentVector");

FormalEvaluationScope scope(SGF);
auto wrappedVal = SGF.B.createStructExtract(Loc, v, wrappedValueVar);
return SGF.emitCheckedGetOptionalValueFrom(Loc, wrappedVal,
/*isImplicitUnwrap*/ true,
SGF.getTypeLowering(wrappedVal.getType()),
ctxt);
}
if (outputTanSpace &&
outputTanSpace->getCanonicalType() == outputSubstType)
return SGF.emitOptionalTangentVectorToTangentVector(Loc, v,
inputSubstType, outputSubstType,
ctxt);
}

// Should have handled the conversion in one of the cases above.
Expand Down

0 comments on commit b47b157

Please sign in to comment.