diff --git a/lib/Target/AArch64/AArch64InstrInfo.cpp b/lib/Target/AArch64/AArch64InstrInfo.cpp index df883d35fa1..0b2c04bea76 100644 --- a/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2426,20 +2426,34 @@ bool AArch64InstrInfo::hasPattern( static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, MachineInstr &Root, SmallVectorImpl &InsInstrs, - unsigned IdxMulOpd, unsigned MaddOpc) { + unsigned IdxMulOpd, unsigned MaddOpc, + const TargetRegisterClass *RC) { assert(IdxMulOpd == 1 || IdxMulOpd == 2); unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1; MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); - MachineOperand R = Root.getOperand(0); - MachineOperand A = MUL->getOperand(1); - MachineOperand B = MUL->getOperand(2); - MachineOperand C = Root.getOperand(IdxOtherOpd); - MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) - .addOperand(R) - .addOperand(A) - .addOperand(B) - .addOperand(C); + unsigned ResultReg = Root.getOperand(0).getReg(); + unsigned SrcReg0 = MUL->getOperand(1).getReg(); + bool Src0IsKill = MUL->getOperand(1).isKill(); + unsigned SrcReg1 = MUL->getOperand(2).getReg(); + bool Src1IsKill = MUL->getOperand(2).isKill(); + unsigned SrcReg2 = Root.getOperand(IdxOtherOpd).getReg(); + bool Src2IsKill = Root.getOperand(IdxOtherOpd).isKill(); + + if (TargetRegisterInfo::isVirtualRegister(ResultReg)) + MRI.constrainRegClass(ResultReg, RC); + if (TargetRegisterInfo::isVirtualRegister(SrcReg0)) + MRI.constrainRegClass(SrcReg0, RC); + if (TargetRegisterInfo::isVirtualRegister(SrcReg1)) + MRI.constrainRegClass(SrcReg1, RC); + if (TargetRegisterInfo::isVirtualRegister(SrcReg2)) + MRI.constrainRegClass(SrcReg2, RC); + + MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), + ResultReg) + .addReg(SrcReg0, getKillRegState(Src0IsKill)) + .addReg(SrcReg1, getKillRegState(Src1IsKill)) + .addReg(SrcReg2, getKillRegState(Src2IsKill)); // Insert the MADD InsInstrs.push_back(MIB); return MUL; @@ -2464,22 +2478,35 @@ static MachineInstr *genMaddR(MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, MachineInstr &Root, SmallVectorImpl &InsInstrs, unsigned IdxMulOpd, unsigned MaddOpc, - unsigned VR) { + unsigned VR, const TargetRegisterClass *RC) { assert(IdxMulOpd == 1 || IdxMulOpd == 2); MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); - MachineOperand R = Root.getOperand(0); - MachineOperand A = MUL->getOperand(1); - MachineOperand B = MUL->getOperand(2); - MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) - .addOperand(R) - .addOperand(A) - .addOperand(B) + unsigned ResultReg = Root.getOperand(0).getReg(); + unsigned SrcReg0 = MUL->getOperand(1).getReg(); + bool Src0IsKill = MUL->getOperand(1).isKill(); + unsigned SrcReg1 = MUL->getOperand(2).getReg(); + bool Src1IsKill = MUL->getOperand(2).isKill(); + + if (TargetRegisterInfo::isVirtualRegister(ResultReg)) + MRI.constrainRegClass(ResultReg, RC); + if (TargetRegisterInfo::isVirtualRegister(SrcReg0)) + MRI.constrainRegClass(SrcReg0, RC); + if (TargetRegisterInfo::isVirtualRegister(SrcReg1)) + MRI.constrainRegClass(SrcReg1, RC); + if (TargetRegisterInfo::isVirtualRegister(VR)) + MRI.constrainRegClass(VR, RC); + + MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), + ResultReg) + .addReg(SrcReg0, getKillRegState(Src0IsKill)) + .addReg(SrcReg1, getKillRegState(Src1IsKill)) .addReg(VR); // Insert the MADD InsInstrs.push_back(MIB); return MUL; } + /// genAlternativeCodeSequence - when hasPattern() finds a pattern /// this function generates the instructions that could replace the /// original code sequence @@ -2494,6 +2521,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( const TargetInstrInfo *TII = MF.getTarget().getSubtargetImpl()->getInstrInfo(); MachineInstr *MUL; + const TargetRegisterClass *RC = nullptr; unsigned Opc; switch (Pattern) { default: @@ -2507,7 +2535,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MADD); Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP1 ? AArch64::MADDWrrr : AArch64::MADDXrrr; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc); + if (Pattern == MachineCombinerPattern::MC_MULADDW_OP1) + RC = &AArch64::GPR32RegClass; + else + RC = &AArch64::GPR64RegClass; + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); break; case MachineCombinerPattern::MC_MULADDW_OP2: case MachineCombinerPattern::MC_MULADDX_OP2: @@ -2517,52 +2549,56 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MADD); Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP2 ? AArch64::MADDWrrr : AArch64::MADDXrrr; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); + if (Pattern == MachineCombinerPattern::MC_MULADDW_OP2) + RC = &AArch64::GPR32RegClass; + else + RC = &AArch64::GPR64RegClass; + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; case MachineCombinerPattern::MC_MULADDWI_OP1: - case MachineCombinerPattern::MC_MULADDXI_OP1: + case MachineCombinerPattern::MC_MULADDXI_OP1: { // MUL I=A,B,0 // ADD R,I,Imm // ==> ORR V, ZR, Imm // ==> MADD R,A,B,V // --- Create(MADD); - { - const TargetRegisterClass *RC = - MRI.getRegClass(Root.getOperand(1).getReg()); - unsigned NewVR = MRI.createVirtualRegister(RC); - unsigned BitSize, OrrOpc, ZeroReg; - if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) { - BitSize = 32; - OrrOpc = AArch64::ORRWri; - ZeroReg = AArch64::WZR; - Opc = AArch64::MADDWrrr; - } else { - OrrOpc = AArch64::ORRXri; - BitSize = 64; - ZeroReg = AArch64::XZR; - Opc = AArch64::MADDXrrr; - } - uint64_t Imm = Root.getOperand(2).getImm(); + const TargetRegisterClass *OrrRC = nullptr; + unsigned BitSize, OrrOpc, ZeroReg; + if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) { + OrrOpc = AArch64::ORRWri; + OrrRC = &AArch64::GPR32spRegClass; + BitSize = 32; + ZeroReg = AArch64::WZR; + Opc = AArch64::MADDWrrr; + RC = &AArch64::GPR32RegClass; + } else { + OrrOpc = AArch64::ORRXri; + OrrRC = &AArch64::GPR64spRegClass; + BitSize = 64; + ZeroReg = AArch64::XZR; + Opc = AArch64::MADDXrrr; + RC = &AArch64::GPR64RegClass; + } + unsigned NewVR = MRI.createVirtualRegister(OrrRC); + uint64_t Imm = Root.getOperand(2).getImm(); - if (Root.getOperand(3).isImm()) { - unsigned val = Root.getOperand(3).getImm(); - Imm = Imm << val; - } - uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize); - uint64_t Encoding; - - if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { - MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) - .addOperand(MachineOperand::CreateReg(NewVR, true)) - .addReg(ZeroReg) - .addImm(Encoding); - InsInstrs.push_back(MIB1); - InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); - } + if (Root.getOperand(3).isImm()) { + unsigned Val = Root.getOperand(3).getImm(); + Imm = Imm << Val; + } + uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize); + uint64_t Encoding; + if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) + .addReg(ZeroReg) + .addImm(Encoding); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); } break; + } case MachineCombinerPattern::MC_MULSUBW_OP1: case MachineCombinerPattern::MC_MULSUBX_OP1: { // MUL I=A,B,0 @@ -2570,29 +2606,32 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // ==> SUB V, 0, C // ==> MADD R,A,B,V // = -C + A*B // --- Create(MADD); - const TargetRegisterClass *RC = - MRI.getRegClass(Root.getOperand(1).getReg()); - unsigned NewVR = MRI.createVirtualRegister(RC); + const TargetRegisterClass *SubRC = nullptr; unsigned SubOpc, ZeroReg; if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP1) { SubOpc = AArch64::SUBWrr; + SubRC = &AArch64::GPR32spRegClass; ZeroReg = AArch64::WZR; Opc = AArch64::MADDWrrr; + RC = &AArch64::GPR32RegClass; } else { SubOpc = AArch64::SUBXrr; + SubRC = &AArch64::GPR64spRegClass; ZeroReg = AArch64::XZR; Opc = AArch64::MADDXrrr; + RC = &AArch64::GPR64RegClass; } + unsigned NewVR = MRI.createVirtualRegister(SubRC); // SUB NewVR, 0, C MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc)) - .addOperand(MachineOperand::CreateReg(NewVR, true)) + BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc), NewVR) .addReg(ZeroReg) .addOperand(Root.getOperand(2)); InsInstrs.push_back(MIB1); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); - } break; + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); + break; + } case MachineCombinerPattern::MC_MULSUBW_OP2: case MachineCombinerPattern::MC_MULSUBX_OP2: // MUL I=A,B,0 @@ -2601,7 +2640,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // --- Create(MSUB); Opc = Pattern == MachineCombinerPattern::MC_MULSUBW_OP2 ? AArch64::MSUBWrrr : AArch64::MSUBXrrr; - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); + if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP2) + RC = &AArch64::GPR32RegClass; + else + RC = &AArch64::GPR64RegClass; + MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; case MachineCombinerPattern::MC_MULSUBWI_OP1: case MachineCombinerPattern::MC_MULSUBXI_OP1: { @@ -2610,40 +2653,43 @@ void AArch64InstrInfo::genAlternativeCodeSequence( // ==> ORR V, ZR, -Imm // ==> MADD R,A,B,V // = -Imm + A*B // --- Create(MADD); - const TargetRegisterClass *RC = - MRI.getRegClass(Root.getOperand(1).getReg()); - unsigned NewVR = MRI.createVirtualRegister(RC); + const TargetRegisterClass *OrrRC = nullptr; unsigned BitSize, OrrOpc, ZeroReg; if (Pattern == MachineCombinerPattern::MC_MULSUBWI_OP1) { - BitSize = 32; OrrOpc = AArch64::ORRWri; + RC = &AArch64::GPR32spRegClass; + BitSize = 32; ZeroReg = AArch64::WZR; Opc = AArch64::MADDWrrr; + RC = &AArch64::GPR32RegClass; } else { OrrOpc = AArch64::ORRXri; + RC = &AArch64::GPR64RegClass; BitSize = 64; ZeroReg = AArch64::XZR; Opc = AArch64::MADDXrrr; + RC = &AArch64::GPR64RegClass; } + unsigned NewVR = MRI.createVirtualRegister(OrrRC); int Imm = Root.getOperand(2).getImm(); if (Root.getOperand(3).isImm()) { - unsigned val = Root.getOperand(3).getImm(); - Imm = Imm << val; + unsigned Val = Root.getOperand(3).getImm(); + Imm = Imm << Val; } uint64_t UImm = -Imm << (64 - BitSize) >> (64 - BitSize); uint64_t Encoding; if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) - .addOperand(MachineOperand::CreateReg(NewVR, true)) + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) .addReg(ZeroReg) .addImm(Encoding); InsInstrs.push_back(MIB1); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); } - } break; + break; } + } // end switch (Pattern) // Record MUL and ADD/SUB for deletion DelInstrs.push_back(MUL); DelInstrs.push_back(&Root); diff --git a/test/CodeGen/AArch64/madd-combiner.ll b/test/CodeGen/AArch64/madd-combiner.ll new file mode 100644 index 00000000000..e7f171b9048 --- /dev/null +++ b/test/CodeGen/AArch64/madd-combiner.ll @@ -0,0 +1,20 @@ +; RUN: llc -mtriple=aarch64-apple-darwin -verify-machineinstrs < %s | FileCheck %s + +; Test that we use the correct register class. +define i32 @mul_add_imm(i32 %a, i32 %b) { +; CHECK-LABEL: mul_add_imm +; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4 +; CHECK-NEXT: madd {{w[0-9]+}}, w0, w1, [[REG]] + %1 = mul i32 %a, %b + %2 = add i32 %1, 4 + ret i32 %2 +} + +define i32 @mul_sub_imm1(i32 %a, i32 %b) { +; CHECK-LABEL: mul_sub_imm1 +; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4 +; CHECK-NEXT: msub {{w[0-9]+}}, w0, w1, [[REG]] + %1 = mul i32 %a, %b + %2 = sub i32 4, %1 + ret i32 %2 +}