Skip to content

Commit

Permalink
Add SUBRK and DIVRK bytecode instructions to bytecode v5 (#1115)
Browse files Browse the repository at this point in the history
Right now, we can compile R\*K for all arithmetic instructions, but K\*R
gets compiled into two instructions (LOADN/LOADK + arithmetic opcode).

This is problematic since it leads to reduced performance for some code.
However, we'd like to avoid adding reverse variants of ADDK et al for
all opcodes to avoid the increase in I$ footprint for interpreter.

Looking at the arithmetic instructions, % and // don't have interesting
use cases for K\*V; ^ is sometimes used with constant on the left hand
side but this would need to call pow() by necessity in all cases so it
would be slow regardless of the dispatch overhead. This leaves the four
basic arithmetic operations.

For + and \*, we can implement a compiler-side optimization in the
future that transforms K\*R to R\*K automatically. This could either be
done unconditionally at -O2, or conditionally based on the type of the
value (driven by type annotations / inference) -- this technically
changes behavior in presence of metamethods, although it might be
sensible to just always do this because non-commutative +/* are evil.

However, for - and / it is impossible for the compiler to optimize this
in the future, so we need dedicated opcodes. This only increases the
interpreter size by ~300 bytes (~1.5%) on X64.

This makes spectral-norm and math-partial-sums 6% faster; maybe more
importantly, voxelgen gets 1.5% faster (so this change does have
real-world impact).

To avoid the proliferation of bytecode versions this change piggybacks
on the bytecode version bump that was just made in 604 for vector
constants; we would still be able to enable these independently but
we'll consider v5 complete when both are enabled.

Related: #626

---------

Co-authored-by: vegorov-rbx <[email protected]>
  • Loading branch information
zeux and vegorov-rbx authored Nov 28, 2023
1 parent 7fb7f43 commit 89b437b
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 60 deletions.
5 changes: 5 additions & 0 deletions CodeGen/include/Luau/IrVisitUseDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i
break;
// A <- B, C
case IrCmd::DO_ARITH:
visitor.maybeUse(inst.b); // Argument can also be a VmConst
visitor.maybeUse(inst.c); // Argument can also be a VmConst

visitor.def(inst.a);
break;
case IrCmd::GET_TABLE:
visitor.use(inst.b);
visitor.maybeUse(inst.c); // Argument can also be a VmConst
Expand Down
4 changes: 2 additions & 2 deletions CodeGen/src/EmitCommonX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, Regi
build.jcc(ConditionX64::NotZero, label);
}

void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm)
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm)
{
IrCallWrapperX64 callWrap(regs, build);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, luauRegAddress(ra));
callWrap.addArgument(SizeX64::qword, luauRegAddress(rb));
callWrap.addArgument(SizeX64::qword, b);
callWrap.addArgument(SizeX64::qword, c);
callWrap.addArgument(SizeX64::dword, tm);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_doarith)]);
Expand Down
2 changes: 1 addition & 1 deletion CodeGen/src/EmitCommonX64.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ ConditionX64 getConditionInt(IrCondition cond);
void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label);

void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm);
void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, OperandX64 b, OperandX64 c, TMS tm);
void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb);
void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);
void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);
Expand Down
6 changes: 6 additions & 0 deletions CodeGen/src/IrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,12 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_POWK:
translateInstBinaryK(*this, pc, i, TM_POW);
break;
case LOP_SUBRK:
translateInstBinaryRK(*this, pc, i, TM_SUB);
break;
case LOP_DIVRK:
translateInstBinaryRK(*this, pc, i, TM_DIV);
break;
case LOP_NOT:
translateInstNot(*this, pc);
break;
Expand Down
6 changes: 5 additions & 1 deletion CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
regs.spill(build, index);
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));

if (inst.b.kind == IrOpKind::VmConst)
emitAddOffset(build, x2, rConstants, vmConstOp(inst.b) * sizeof(TValue));
else
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));

if (inst.c.kind == IrOpKind::VmConst)
emitAddOffset(build, x3, rConstants, vmConstOp(inst.c) * sizeof(TValue));
Expand Down
9 changes: 5 additions & 4 deletions CodeGen/src/IrLoweringX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,11 +962,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
break;
}
case IrCmd::DO_ARITH:
if (inst.c.kind == IrOpKind::VmReg)
callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauRegAddress(vmRegOp(inst.c)), TMS(intOp(inst.d)));
else
callArithHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), luauConstantAddress(vmConstOp(inst.c)), TMS(intOp(inst.d)));
{
OperandX64 opb = inst.b.kind == IrOpKind::VmReg ? luauRegAddress(vmRegOp(inst.b)) : luauConstantAddress(vmConstOp(inst.b));
OperandX64 opc = inst.c.kind == IrOpKind::VmReg ? luauRegAddress(vmRegOp(inst.c)) : luauConstantAddress(vmConstOp(inst.c));
callArithHelper(regs, build, vmRegOp(inst.a), opb, opc, TMS(intOp(inst.d)));
break;
}
case IrCmd::DO_LEN:
callLengthHelper(regs, build, vmRegOp(inst.a), vmRegOp(inst.b));
break;
Expand Down
41 changes: 32 additions & 9 deletions CodeGen/src/IrTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,25 +327,40 @@ void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos)
build.beginBlock(next);
}

static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opc, int pcpos, TMS tm)
static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, IrOp opb, IrOp opc, int pcpos, TMS tm)
{
IrOp fallback = build.block(IrBlockKind::Fallback);

// fast-path: number
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback);
if (rb != -1)
{
IrOp tb = build.inst(IrCmd::LOAD_TAG, build.vmReg(rb));
build.inst(IrCmd::CHECK_TAG, tb, build.constTag(LUA_TNUMBER), fallback);
}

if (rc != -1 && rc != rb) // TODO: optimization should handle second check, but we'll test it later
{
IrOp tc = build.inst(IrCmd::LOAD_TAG, build.vmReg(rc));
build.inst(IrCmd::CHECK_TAG, tc, build.constTag(LUA_TNUMBER), fallback);
}

IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb));
IrOp vc;

IrOp vb, vc;
IrOp result;

if (opb.kind == IrOpKind::VmConst)
{
LUAU_ASSERT(build.function.proto);
TValue protok = build.function.proto->k[vmConstOp(opb)];

LUAU_ASSERT(protok.tt == LUA_TNUMBER);

vb = build.constDouble(protok.value.n);
}
else
{
vb = build.inst(IrCmd::LOAD_DOUBLE, opb);
}

if (opc.kind == IrOpKind::VmConst)
{
LUAU_ASSERT(build.function.proto);
Expand Down Expand Up @@ -409,18 +424,26 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc,
FallbackStreamScope scope(build, fallback, next);

build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), build.vmReg(rb), opc, build.constInt(tm));
build.inst(IrCmd::DO_ARITH, build.vmReg(ra), opb, opc, build.constInt(tm));
build.inst(IrCmd::JUMP, next);
}

void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm);
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), build.vmReg(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm);
}

void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm);
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, build.vmReg(LUAU_INSN_B(*pc)), build.vmConst(LUAU_INSN_C(*pc)), pcpos, tm);
}

void translateInstBinaryRK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm)
{
translateInstBinaryNumeric(
build, LUAU_INSN_A(*pc), -1, LUAU_INSN_C(*pc), build.vmConst(LUAU_INSN_B(*pc)), build.vmReg(LUAU_INSN_C(*pc)), pcpos, tm);
}

void translateInstNot(IrBuilder& build, const Instruction* pc)
Expand Down
1 change: 1 addition & 0 deletions CodeGen/src/IrTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstJumpxEqS(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstBinary(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm);
void translateInstBinaryK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm);
void translateInstBinaryRK(IrBuilder& build, const Instruction* pc, int pcpos, TMS tm);
void translateInstNot(IrBuilder& build, const Instruction* pc);
void translateInstMinus(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos);
Expand Down
13 changes: 8 additions & 5 deletions Common/include/Luau/Bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
// Version 2: Adds Proto::linedefined. Supported until 0.544.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported.
// Version 5: Adds vector constants. Currently supported.
// Version 5: Adds SUBRK/DIVRK and vector constants. Currently supported.

// Bytecode opcode, part of the instruction header
enum LuauOpcode
Expand Down Expand Up @@ -219,7 +219,7 @@ enum LuauOpcode
// ADDK, SUBK, MULK, DIVK, MODK, POWK: compute arithmetic operation between the source register and a constant and put the result into target register
// A: target register
// B: source register
// C: constant table index (0..255)
// C: constant table index (0..255); must refer to a number
LOP_ADDK,
LOP_SUBK,
LOP_MULK,
Expand Down Expand Up @@ -348,9 +348,12 @@ enum LuauOpcode
// B: source register (for VAL/REF) or upvalue index (for UPVAL/UPREF)
LOP_CAPTURE,

// removed in v3
LOP_DEP_JUMPIFEQK,
LOP_DEP_JUMPIFNOTEQK,
// SUBRK, DIVRK: compute arithmetic operation between the constant and a source register and put the result into target register
// A: target register
// B: source register
// C: constant table index (0..255); must refer to a number
LOP_SUBRK,
LOP_DIVRK,

// FASTCALL1: perform a fast call of a built-in function using 1 register argument
// A: builtin function id (see LuauBuiltinFunction)
Expand Down
22 changes: 21 additions & 1 deletion Compiler/src/BytecodeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string.h>

LUAU_FASTFLAG(LuauVectorLiterals)
LUAU_FASTFLAG(LuauCompileRevK)

namespace Luau
{
Expand Down Expand Up @@ -1123,7 +1124,7 @@ std::string BytecodeBuilder::getError(const std::string& message)
uint8_t BytecodeBuilder::getVersion()
{
// This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags
return (FFlag::LuauVectorLiterals ? 5 : LBC_VERSION_TARGET);
return (FFlag::LuauVectorLiterals || FFlag::LuauCompileRevK) ? 5 : LBC_VERSION_TARGET;
}

uint8_t BytecodeBuilder::getTypeEncodingVersion()
Expand Down Expand Up @@ -1351,6 +1352,13 @@ void BytecodeBuilder::validateInstructions() const
VCONST(LUAU_INSN_C(insn), Number);
break;

case LOP_SUBRK:
case LOP_DIVRK:
VREG(LUAU_INSN_A(insn));
VCONST(LUAU_INSN_B(insn), Number);
VREG(LUAU_INSN_C(insn));
break;

case LOP_AND:
case LOP_OR:
VREG(LUAU_INSN_A(insn));
Expand Down Expand Up @@ -1973,6 +1981,18 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
result.append("]\n");
break;

case LOP_SUBRK:
formatAppend(result, "SUBRK R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn));
dumpConstant(result, LUAU_INSN_B(insn));
formatAppend(result, "] R%d\n", LUAU_INSN_C(insn));
break;

case LOP_DIVRK:
formatAppend(result, "DIVRK R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn));
dumpConstant(result, LUAU_INSN_B(insn));
formatAppend(result, "] R%d\n", LUAU_INSN_C(insn));
break;

case LOP_AND:
formatAppend(result, "AND R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn));
break;
Expand Down
16 changes: 16 additions & 0 deletions Compiler/src/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileSideEffects, false)
LUAU_FASTFLAGVARIABLE(LuauCompileDeadIf, false)

LUAU_FASTFLAGVARIABLE(LuauCompileRevK, false)

namespace Luau
{

Expand Down Expand Up @@ -1516,6 +1518,20 @@ struct Compiler
}
else
{
if (FFlag::LuauCompileRevK && (expr->op == AstExprBinary::Sub || expr->op == AstExprBinary::Div))
{
int32_t lc = getConstantNumber(expr->left);

if (lc >= 0 && lc <= 255)
{
uint8_t rr = compileExprAuto(expr->right, rs);
LuauOpcode op = (expr->op == AstExprBinary::Sub) ? LOP_SUBRK : LOP_DIVRK;

bytecode.emitABC(op, target, uint8_t(lc), uint8_t(rr));
return;
}
}

uint8_t rl = compileExprAuto(expr->left, rs);
uint8_t rr = compileExprAuto(expr->right, rs);

Expand Down
57 changes: 47 additions & 10 deletions VM/src/lvmexecute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_DEP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \
VM_DISPATCH_OP(LOP_NATIVECALL), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \
VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \
VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_DEP_JUMPIFEQK), VM_DISPATCH_OP(LOP_DEP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \
VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_SUBRK), VM_DISPATCH_OP(LOP_DIVRK), VM_DISPATCH_OP(LOP_FASTCALL1), \
VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \
VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), VM_DISPATCH_OP(LOP_IDIV), \
VM_DISPATCH_OP(LOP_IDIVK),
Expand Down Expand Up @@ -1858,9 +1858,9 @@ static void luau_execute(lua_State* L)
}
else if (ttisvector(rb))
{
const float* vb = rb->value.v;
float vc = cast_to(float, nvalue(kv));
setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc);
const float* vb = vvalue(rb);
float nc = cast_to(float, nvalue(kv));
setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc);
VM_NEXT();
}
else
Expand Down Expand Up @@ -2697,16 +2697,53 @@ static void luau_execute(lua_State* L)
LUAU_UNREACHABLE();
}

VM_CASE(LOP_DEP_JUMPIFEQK)
VM_CASE(LOP_SUBRK)
{
LUAU_ASSERT(!"Unsupported deprecated opcode");
LUAU_UNREACHABLE();
Instruction insn = *pc++;
StkId ra = VM_REG(LUAU_INSN_A(insn));
TValue* kv = VM_KV(LUAU_INSN_B(insn));
StkId rc = VM_REG(LUAU_INSN_C(insn));

// fast-path
if (ttisnumber(rc))
{
setnvalue(ra, nvalue(kv) - nvalue(rc));
VM_NEXT();
}
else
{
// slow-path, may invoke C/Lua via metamethods
VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_SUB));
VM_NEXT();
}
}

VM_CASE(LOP_DEP_JUMPIFNOTEQK)
VM_CASE(LOP_DIVRK)
{
LUAU_ASSERT(!"Unsupported deprecated opcode");
LUAU_UNREACHABLE();
Instruction insn = *pc++;
StkId ra = VM_REG(LUAU_INSN_A(insn));
TValue* kv = VM_KV(LUAU_INSN_B(insn));
StkId rc = VM_REG(LUAU_INSN_C(insn));

// fast-path
if (LUAU_LIKELY(ttisnumber(rc)))
{
setnvalue(ra, nvalue(kv) / nvalue(rc));
VM_NEXT();
}
else if (ttisvector(rc))
{
float nb = cast_to(float, nvalue(kv));
const float* vc = vvalue(rc);
setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]);
VM_NEXT();
}
else
{
// slow-path, may invoke C/Lua via metamethods
VM_PROTECT(luaV_doarith(L, ra, kv, rc, TM_DIV));
VM_NEXT();
}
}

VM_CASE(LOP_FASTCALL1)
Expand Down
Loading

0 comments on commit 89b437b

Please sign in to comment.