From 73d6c3b382e8e0e4268ac68a2aa4f82c788e788f Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Thu, 16 Nov 2023 14:38:58 +0900 Subject: [PATCH] wazevo(arm64): fixes Vector select (#1840) Signed-off-by: Takeshi Yoneda --- .../engine/wazevo/backend/isa/arm64/instr.go | 11 ++++-- .../backend/isa/arm64/instr_encoding.go | 12 +++++-- .../backend/isa/arm64/instr_encoding_test.go | 6 ++-- .../wazevo/backend/isa/arm64/lower_instr.go | 35 ++++++++----------- .../backend/isa/arm64/lower_instr_test.go | 7 ++-- .../fuzzcases/fuzzcases_test.go | 2 ++ 6 files changed, 41 insertions(+), 32 deletions(-) diff --git a/internal/engine/wazevo/backend/isa/arm64/instr.go b/internal/engine/wazevo/backend/isa/arm64/instr.go index 44e1e99836..ee364736ae 100644 --- a/internal/engine/wazevo/backend/isa/arm64/instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/instr.go @@ -586,10 +586,13 @@ func (i *instruction) asVecLoad1R(rd, rn operand, arr vecArrangement) { i.u1 = uint64(arr) } -func (i *instruction) asCSet(rd regalloc.VReg, c condFlag) { +func (i *instruction) asCSet(rd regalloc.VReg, mask bool, c condFlag) { i.kind = cSet i.rd = operandNR(rd) i.u1 = uint64(c) + if mask { + i.u2 = 1 + } } func (i *instruction) asCSel(rd, rn, rm operand, c condFlag, _64bit bool) { @@ -1092,7 +1095,11 @@ func (i *instruction) String() (str string) { condFlag(i.u1), ) case cSet: - str = fmt.Sprintf("cset %s, %s", formatVRegSized(i.rd.nr(), 64), condFlag(i.u1)) + if i.u2 != 0 { + str = fmt.Sprintf("csetm %s, %s", formatVRegSized(i.rd.nr(), 64), condFlag(i.u1)) + } else { + str = fmt.Sprintf("cset %s, %s", formatVRegSized(i.rd.nr(), 64), condFlag(i.u1)) + } case cCmpImm: size := is64SizeBitToSize(i.u3) str = fmt.Sprintf("ccmp %s, #%#x, #%#x, %s", diff --git a/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go b/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go index 13fe4459e7..ef5a6f1265 100644 --- a/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go +++ 
b/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go @@ -214,9 +214,15 @@ func (i *instruction) encode(c backend.Compiler) { case cSet: rd := regNumberInEncoding[i.rd.realReg()] cf := condFlag(i.u1) - // https://developer.arm.com/documentation/ddi0602/2022-06/Base-Instructions/CSET--Conditional-Set--an-alias-of-CSINC- - // Note that we set 64bit version here. - c.Emit4Bytes(0b1001101010011111<<16 | uint32(cf.invert())<<12 | 0b111111<<5 | rd) + if i.u2 == 1 { + // https://developer.arm.com/documentation/ddi0602/2022-03/Base-Instructions/CSETM--Conditional-Set-Mask--an-alias-of-CSINV- + // Note that we set 64bit version here. + c.Emit4Bytes(0b1101101010011111<<16 | uint32(cf.invert())<<12 | 0b011111<<5 | rd) + } else { + // https://developer.arm.com/documentation/ddi0602/2022-06/Base-Instructions/CSET--Conditional-Set--an-alias-of-CSINC- + // Note that we set 64bit version here. + c.Emit4Bytes(0b1001101010011111<<16 | uint32(cf.invert())<<12 | 0b111111<<5 | rd) + } case extend: c.Emit4Bytes(encodeExtend(i.u3 == 1, byte(i.u1), byte(i.u2), regNumberInEncoding[i.rd.realReg()], regNumberInEncoding[i.rn.realReg()])) case fpuCmp: diff --git a/internal/engine/wazevo/backend/isa/arm64/instr_encoding_test.go b/internal/engine/wazevo/backend/isa/arm64/instr_encoding_test.go index d943906d6d..20265a940b 100644 --- a/internal/engine/wazevo/backend/isa/arm64/instr_encoding_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/instr_encoding_test.go @@ -1098,8 +1098,10 @@ func TestInstruction_encode(t *testing.T) { {want: "b21c4093", setup: func(i *instruction) { i.asExtend(x18VReg, x5VReg, 8, 64, true) }}, {want: "b23c4093", setup: func(i *instruction) { i.asExtend(x18VReg, x5VReg, 16, 64, true) }}, {want: "b27c4093", setup: func(i *instruction) { i.asExtend(x18VReg, x5VReg, 32, 64, true) }}, - {want: "f2079f9a", setup: func(i *instruction) { i.asCSet(x18VReg, ne) }}, - {want: "f2179f9a", setup: func(i *instruction) { i.asCSet(x18VReg, eq) }}, + {want: "f2079f9a", setup: 
func(i *instruction) { i.asCSet(x18VReg, false, ne) }}, + {want: "f2179f9a", setup: func(i *instruction) { i.asCSet(x18VReg, false, eq) }}, + {want: "e0039fda", setup: func(i *instruction) { i.asCSet(x0VReg, true, ne) }}, + {want: "f2139fda", setup: func(i *instruction) { i.asCSet(x18VReg, true, eq) }}, {want: "32008012", setup: func(i *instruction) { i.asMOVN(x18VReg, 1, 0, false) }}, {want: "52559512", setup: func(i *instruction) { i.asMOVN(x18VReg, 0xaaaa, 0, false) }}, {want: "f2ff9f12", setup: func(i *instruction) { i.asMOVN(x18VReg, 0xffff, 0, false) }}, diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go index 9d335cb797..5c70ee6d13 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go @@ -809,7 +809,7 @@ func (m *machine) lowerVcheckTrue(op ssa.Opcode, rm, rd operand, arr vecArrangem m.insert(fcmp) cset := m.allocateInstr() - cset.asCSet(rd.nr(), eq) + cset.asCSet(rd.nr(), false, eq) m.insert(cset) return @@ -840,7 +840,7 @@ func (m *machine) lowerVcheckTrue(op ssa.Opcode, rm, rd operand, arr vecArrangem m.insert(fc) cset := m.allocateInstr() - cset.asCSet(rd.nr(), ne) + cset.asCSet(rd.nr(), false, ne) m.insert(cset) } @@ -1438,7 +1438,7 @@ func (m *machine) lowerIcmp(si *ssa.Instruction) { m.insert(alu) cset := m.allocateInstr() - cset.asCSet(m.compiler.VRegOf(si.Return()), flag) + cset.asCSet(m.compiler.VRegOf(si.Return()), false, flag) m.insert(cset) } @@ -1654,7 +1654,7 @@ func (m *machine) lowerFcmp(x, y, result ssa.Value, c ssa.FloatCmpCond) { m.insert(fc) cset := m.allocateInstr() - cset.asCSet(m.compiler.VRegOf(result), condFlagFromSSAFloatCmpCond(c)) + cset.asCSet(m.compiler.VRegOf(result), false, condFlagFromSSAFloatCmpCond(c)) m.insert(cset) } @@ -1906,27 +1906,20 @@ func (m *machine) lowerSelect(c, x, y, result ssa.Value) { } func (m *machine) lowerSelectVec(rc, rn, rm, rd operand) { - // First, 
we copy the condition to a temporary register in case rc is used somewhere else. - tmp := m.compiler.AllocateVReg(ssa.TypeI32) - mov := m.allocateInstr() - mov.asMove32(tmp, rc.nr()) - m.insert(mov) - - // Next is to clear the unnecessary bits of rc by ANDing it with 1, and store it to a temporary register. - oneOrZero := m.compiler.AllocateVReg(ssa.TypeI32) - and := m.allocateInstr() - and.asALUBitmaskImm(aluOpAnd, oneOrZero, tmp, 1, false) - m.insert(and) + // First check if `rc` is zero or not. + checkZero := m.allocateInstr() + checkZero.asALU(aluOpSubS, operandNR(xzrVReg), rc, operandNR(xzrVReg), false) + m.insert(checkZero) - // Sets all bits to 1 if rc is not zero. - allOneOrZero := operandNR(m.compiler.AllocateVReg(ssa.TypeI64)) - alu := m.allocateInstr() - alu.asALU(aluOpSub, allOneOrZero, operandNR(xzrVReg), operandNR(oneOrZero), true) - m.insert(alu) + // Then use CSETM to set all bits to one if `rc` is not zero. + allOnesOrZero := m.compiler.AllocateVReg(ssa.TypeI64) + cset := m.allocateInstr() + cset.asCSet(allOnesOrZero, true, ne) + m.insert(cset) // Then move the bits to the result vector register. dup := m.allocateInstr() - dup.asVecDup(rd, allOneOrZero, vecArrangement2D) + dup.asVecDup(rd, operandNR(allOnesOrZero), vecArrangement2D) m.insert(dup) // Now that `rd` has either all bits one or zero depending on `rc`, diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go index aa43c75cb3..2f96934116 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go @@ -853,10 +853,9 @@ func TestMachine_lowerSelectVec(t *testing.T) { m.lowerSelectVec(c, rn, rm, rd) require.Equal(t, ` -mov w5?, w1? -and w6?, w5?, #0x1 -sub x7?, xzr, x6? -dup v4?.2d, x7? +subs wzr, w1?, wzr +csetm x5?, ne +dup v4?.2d, x5? 
bsl v4?.16b, v2?.16b, v3?.16b `, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") } diff --git a/internal/integration_test/fuzzcases/fuzzcases_test.go b/internal/integration_test/fuzzcases/fuzzcases_test.go index ac6d4cb80a..0731e39b61 100644 --- a/internal/integration_test/fuzzcases/fuzzcases_test.go +++ b/internal/integration_test/fuzzcases/fuzzcases_test.go @@ -92,9 +92,11 @@ func Test696(t *testing.T) { in uint64 exp [2]uint64 }{ + {fnName: "select", in: 1 << 5, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, {fnName: "select", in: 1, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, {fnName: "select", in: 0, exp: [2]uint64{0x1111111111111111, 0x2222222222222222}}, {fnName: "select", in: 0xffffff, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, + {fnName: "select", in: 0xffff00, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, {fnName: "select", in: 0x000000, exp: [2]uint64{0x1111111111111111, 0x2222222222222222}}, {fnName: "typed select", in: 1, exp: [2]uint64{0xffffffffffffffff, 0xeeeeeeeeeeeeeeee}}, {fnName: "typed select", in: 0, exp: [2]uint64{0x1111111111111111, 0x2222222222222222}},