Skip to content

Commit

Permalink
Opt: Get rid of the LiftTry phase; instead handle things in the back-…
Browse files Browse the repository at this point in the history
…end. (#18619)

When we enter a `try-catch` at the JVM level, we have to make sure that
the stack is empty. That's because, upon exception, the JVM wipes the
stack, and we must not lose operands that are already on the stack that
we will still use.

Previously, this was achieved with a transformation phase, `LiftTry`,
which lifted problematic `try-catch`es in local `def`s, called
`liftedTree$x`. It analyzed the tree to predict which `try-catch`es
would execute on a non-empty stack when eventually compiled to the JVM.

This approach has several shortcomings.

It exhibits performance cliffs, as the generated def can then cause more
variables to be boxed in to `XRef`s. These were the only extra defs
created for implementation reasons rather than for language reasons. As
a user of the language, it is hard to predict when such a lifted def
will be needed.

The additional `liftedTree` methods also show up on stack traces and
obfuscate them. Debugging can be severely hampered as well.

Phases executing after `LiftTry`, notably `CapturedVars`, also had to
take care not to create more problematic situations as a result of their
transformations, which is hard to predict and to remember.

Finally, Scala.js and Scala Native do not have the same restriction, so
they received suboptimal code for no reason.

In this commit, we entirely remove the `LiftTry` phase. Instead, we
enhance the JVM back-end to deal with the situation. When starting a
`try-catch` on a non-empty stack, we stash the entire contents of the
stack into local variables. After the `try-catch`, we pop all those
local variables back onto the stack. We also null out the leftover vars
not to prevent garbage collection.

This new approach solves all of the problems mentioned above.

---

This could be back-ported to Scala 2 if there is interest.

/cc @adpi2 who wanted this to improve debugging.
  • Loading branch information
sjrd authored Oct 4, 2023
2 parents 76df8d9 + 52e8e74 commit 1bf0f6d
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 206 deletions.
113 changes: 62 additions & 51 deletions compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {

tree match {
case Assign(lhs @ DesugaredSelect(qual, _), rhs) =>
val savedStackHeight = stackHeight
val savedStackSize = stack.recordSize()
val isStatic = lhs.symbol.isStaticMember
if (!isStatic) {
genLoadQualifier(lhs)
stackHeight += 1
val qualTK = genLoad(qual)
stack.push(qualTK)
}
genLoad(rhs, symInfoTK(lhs.symbol))
stackHeight = savedStackHeight
stack.restoreSize(savedStackSize)
lineNumber(tree)
// receiverClass is used in the bytecode to access the field. using sym.owner may lead to IllegalAccessError
val receiverClass = qual.tpe.typeSymbol
Expand Down Expand Up @@ -150,9 +150,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

genLoad(larg, resKind)
stackHeight += resKind.size
stack.push(resKind)
genLoad(rarg, if (isShift) INT else resKind)
stackHeight -= resKind.size
stack.pop()

(code: @switch) match {
case ADD => bc add resKind
Expand Down Expand Up @@ -189,19 +189,19 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
if (isArrayGet(code)) {
// load argument on stack
assert(args.length == 1, s"Too many arguments for array get operation: $tree");
stackHeight += 1
stack.push(k)
genLoad(args.head, INT)
stackHeight -= 1
stack.pop()
generatedType = k.asArrayBType.componentType
bc.aload(elementType)
}
else if (isArraySet(code)) {
val List(a1, a2) = args
stackHeight += 1
stack.push(k)
genLoad(a1, INT)
stackHeight += 1
stack.push(INT)
genLoad(a2)
stackHeight -= 2
stack.pop(2)
generatedType = UNIT
bc.astore(elementType)
} else {
Expand Down Expand Up @@ -235,7 +235,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)

val postIf = new asm.Label
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stackHeight))
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stack.recordSize()))
markProgramPoint(failure)
genLoadTo(elsep, resKind, LoadDestination.FallThrough)
markProgramPoint(postIf)
Expand Down Expand Up @@ -294,8 +294,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
)
}

def genLoad(tree: Tree): Unit = {
genLoad(tree, tpeTK(tree))
def genLoad(tree: Tree): BType = {
val generatedType = tpeTK(tree)
genLoad(tree, generatedType)
generatedType
}

/* Generate code for trees that produce values on the stack */
Expand Down Expand Up @@ -364,6 +366,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
case t @ Ident(_) => (t, Nil)
}

val savedStackSize = stack.recordSize()
if (!fun.symbol.isStaticMember) {
// load receiver of non-static implementation of lambda

Expand All @@ -372,10 +375,12 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
// AbstractValidatingLambdaMetafactory.validateMetafactoryArgs

val DesugaredSelect(prefix, _) = fun: @unchecked
genLoad(prefix)
val prefixTK = genLoad(prefix)
stack.push(prefixTK)
}

genLoadArguments(env, fun.symbol.info.firstParamTypes map toTypeKind)
stack.restoreSize(savedStackSize)
generatedType = genInvokeDynamicLambda(NoSymbol, fun.symbol, env.size, functionalInterface)

case app @ Apply(_, _) =>
Expand Down Expand Up @@ -494,9 +499,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
dest match
case LoadDestination.FallThrough =>
()
case LoadDestination.Jump(label, targetStackHeight) =>
if targetStackHeight < stackHeight then
val stackDiff = stackHeight - targetStackHeight
case LoadDestination.Jump(label, targetStackSize) =>
val stackDiff = stack.heightDiffWrt(targetStackSize)
if stackDiff != 0 then
if expectedType == UNIT then
bc dropMany stackDiff
else
Expand Down Expand Up @@ -599,7 +604,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
if dest == LoadDestination.FallThrough then
val resKind = tpeTK(tree)
val jumpTarget = new asm.Label
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stackHeight))
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stack.recordSize()))
genLoad(expr, resKind)
markProgramPoint(jumpTarget)
resKind
Expand Down Expand Up @@ -657,7 +662,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
markProgramPoint(loop)

if isInfinite then
val dest = LoadDestination.Jump(loop, stackHeight)
val dest = LoadDestination.Jump(loop, stack.recordSize())
genLoadTo(body, UNIT, dest)
dest
else
Expand All @@ -672,7 +677,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val failure = new asm.Label
genCond(cond, success, failure, targetIfNoJump = success)
markProgramPoint(success)
genLoadTo(body, UNIT, LoadDestination.Jump(loop, stackHeight))
genLoadTo(body, UNIT, LoadDestination.Jump(loop, stack.recordSize()))
markProgramPoint(failure)
end match
LoadDestination.FallThrough
Expand Down Expand Up @@ -765,10 +770,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
// on the stack (contrary to what the type in the AST says).

// scala/bug#10290: qual can be `this.$outer()` (not just `this`), so we call genLoad (not just ALOAD_0)
genLoad(superQual)
stackHeight += 1
val superQualTK = genLoad(superQual)
stack.push(superQualTK)
genLoadArguments(args, paramTKs(app))
stackHeight -= 1
stack.pop()
generatedType = genCallMethod(fun.symbol, InvokeStyle.Super, app.span)

// 'new' constructor call: Note: since constructors are
Expand All @@ -790,9 +795,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
assert(classBTypeFromSymbol(ctor.owner) == rt, s"Symbol ${ctor.owner.showFullName} is different from $rt")
mnode.visitTypeInsn(asm.Opcodes.NEW, rt.internalName)
bc dup generatedType
stackHeight += 2
stack.push(rt)
stack.push(rt)
genLoadArguments(args, paramTKs(app))
stackHeight -= 2
stack.pop(2)
genCallMethod(ctor, InvokeStyle.Special, app.span)

case _ =>
Expand Down Expand Up @@ -825,12 +831,11 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
else if (app.hasAttachment(BCodeHelpers.UseInvokeSpecial)) InvokeStyle.Special
else InvokeStyle.Virtual

val savedStackHeight = stackHeight
val savedStackSize = stack.recordSize()
if invokeStyle.hasInstance then
genLoadQualifier(fun)
stackHeight += 1
stack.push(genLoadQualifier(fun))
genLoadArguments(args, paramTKs(app))
stackHeight = savedStackHeight
stack.restoreSize(savedStackSize)

val DesugaredSelect(qual, name) = fun: @unchecked // fun is a Select, also checked in genLoadQualifier
val isArrayClone = name == nme.clone_ && qual.tpe.widen.isInstanceOf[JavaArrayType]
Expand Down Expand Up @@ -888,7 +893,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
bc iconst elems.length
bc newarray elmKind

stackHeight += 3 // during the genLoad below, there is the result, its dup, and the index
// during the genLoad below, there is the result, its dup, and the index
stack.push(generatedType)
stack.push(generatedType)
stack.push(INT)

var i = 0
var rest = elems
Expand All @@ -901,7 +909,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
i = i + 1
}

stackHeight -= 3
stack.pop(3)

generatedType
}
Expand All @@ -917,7 +925,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val (generatedType, postMatch, postMatchDest) =
if dest == LoadDestination.FallThrough then
val postMatch = new asm.Label
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stackHeight))
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stack.recordSize()))
else
(expectedType, null, dest)

Expand Down Expand Up @@ -1179,7 +1187,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

/* Emit code to Load the qualifier of `tree` on top of the stack. */
def genLoadQualifier(tree: Tree): Unit = {
def genLoadQualifier(tree: Tree): BType = {
lineNumber(tree)
tree match {
case DesugaredSelect(qualifier, _) => genLoad(qualifier)
Expand All @@ -1188,6 +1196,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
case Some(sel) => genLoadQualifier(sel)
case None =>
assert(t.symbol.owner == this.claszSymbol)
UNIT
}
case _ => abort(s"Unknown qualifier $tree")
}
Expand All @@ -1200,14 +1209,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
btpes match
case btpe :: btpes1 =>
genLoad(arg, btpe)
stackHeight += btpe.size
stack.push(btpe)
loop(args1, btpes1)
case _ =>
case _ =>

val savedStackHeight = stackHeight
val savedStackSize = stack.recordSize()
loop(args, btpes)
stackHeight = savedStackHeight
stack.restoreSize(savedStackSize)
end genLoadArguments

def genLoadModule(tree: Tree): BType = {
Expand Down Expand Up @@ -1307,13 +1316,13 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}.sum
bc.genNewStringBuilder(approxBuilderSize)

stackHeight += 1 // during the genLoad below, there is a reference to the StringBuilder on the stack
stack.push(jlStringBuilderRef) // during the genLoad below, there is a reference to the StringBuilder on the stack
for (elem <- concatArguments) {
val elemType = tpeTK(elem)
genLoad(elem, elemType)
bc.genStringBuilderAppend(elemType)
}
stackHeight -= 1
stack.pop()

bc.genStringBuilderEnd
} else {
Expand All @@ -1331,15 +1340,17 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
var totalArgSlots = 0
var countConcats = 1 // ie. 1 + how many times we spilled

val savedStackHeight = stackHeight
val savedStackSize = stack.recordSize()

for (elem <- concatArguments) {
val tpe = tpeTK(elem)
val elemSlots = tpe.size

// Unlikely spill case
if (totalArgSlots + elemSlots >= MaxIndySlots) {
stackHeight = savedStackHeight + countConcats
stack.restoreSize(savedStackSize)
for _ <- 0 until countConcats do
stack.push(StringRef)
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())
countConcats += 1
totalArgSlots = 0
Expand All @@ -1364,10 +1375,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val tpe = tpeTK(elem)
argTypes += tpe.toASMType
genLoad(elem, tpe)
stackHeight += 1
stack.push(tpe)
}
}
stackHeight = savedStackHeight
stack.restoreSize(savedStackSize)
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())

// If we spilled, generate one final concat
Expand Down Expand Up @@ -1562,9 +1573,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
} else {
val tk = tpeTK(l).maxType(tpeTK(r))
genLoad(l, tk)
stackHeight += tk.size
stack.push(tk)
genLoad(r, tk)
stackHeight -= tk.size
stack.pop()
genCJUMP(success, failure, op, tk, targetIfNoJump)
}
}
Expand Down Expand Up @@ -1679,9 +1690,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}

genLoad(l, ObjectRef)
stackHeight += 1
stack.push(ObjectRef)
genLoad(r, ObjectRef)
stackHeight -= 1
stack.pop()
genCallMethod(equalsMethod, InvokeStyle.Static)
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
}
Expand All @@ -1697,9 +1708,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
} else if (isNonNullExpr(l)) {
// SI-7852 Avoid null check if L is statically non-null.
genLoad(l, ObjectRef)
stackHeight += 1
stack.push(ObjectRef)
genLoad(r, ObjectRef)
stackHeight -= 1
stack.pop()
genCallMethod(defn.Any_equals, InvokeStyle.Virtual)
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
} else {
Expand All @@ -1709,9 +1720,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val lNonNull = new asm.Label

genLoad(l, ObjectRef)
stackHeight += 1
stack.push(ObjectRef)
genLoad(r, ObjectRef)
stackHeight -= 1
stack.pop()
locals.store(eqEqTempLocal)
bc dup ObjectRef
genCZJUMP(lNull, lNonNull, Primitives.EQ, ObjectRef, targetIfNoJump = lNull)
Expand Down
Loading

0 comments on commit 1bf0f6d

Please sign in to comment.