[MachineCombiner][AArch64] Use the correct register class for MADD, SUB, and OR.

Select the correct register class for the various instructions that are
generated when combining instructions and constrain the registers to the
appropriate register class.

This fixes rdar://problem/18183707.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@216805 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Juergen Ributzka 2014-08-29 23:48:09 +00:00
parent e7f301e079
commit 4e92383b67
2 changed files with 139 additions and 73 deletions

View File

@ -2426,20 +2426,34 @@ bool AArch64InstrInfo::hasPattern(
static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI, static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI,
const TargetInstrInfo *TII, MachineInstr &Root, const TargetInstrInfo *TII, MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs, SmallVectorImpl<MachineInstr *> &InsInstrs,
unsigned IdxMulOpd, unsigned MaddOpc) { unsigned IdxMulOpd, unsigned MaddOpc,
const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1 || IdxMulOpd == 2); assert(IdxMulOpd == 1 || IdxMulOpd == 2);
unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1; unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1;
MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg());
MachineOperand R = Root.getOperand(0); unsigned ResultReg = Root.getOperand(0).getReg();
MachineOperand A = MUL->getOperand(1); unsigned SrcReg0 = MUL->getOperand(1).getReg();
MachineOperand B = MUL->getOperand(2); bool Src0IsKill = MUL->getOperand(1).isKill();
MachineOperand C = Root.getOperand(IdxOtherOpd); unsigned SrcReg1 = MUL->getOperand(2).getReg();
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) bool Src1IsKill = MUL->getOperand(2).isKill();
.addOperand(R) unsigned SrcReg2 = Root.getOperand(IdxOtherOpd).getReg();
.addOperand(A) bool Src2IsKill = Root.getOperand(IdxOtherOpd).isKill();
.addOperand(B)
.addOperand(C); if (TargetRegisterInfo::isVirtualRegister(ResultReg))
MRI.constrainRegClass(ResultReg, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg0))
MRI.constrainRegClass(SrcReg0, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg1))
MRI.constrainRegClass(SrcReg1, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg2))
MRI.constrainRegClass(SrcReg2, RC);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc),
ResultReg)
.addReg(SrcReg0, getKillRegState(Src0IsKill))
.addReg(SrcReg1, getKillRegState(Src1IsKill))
.addReg(SrcReg2, getKillRegState(Src2IsKill));
// Insert the MADD // Insert the MADD
InsInstrs.push_back(MIB); InsInstrs.push_back(MIB);
return MUL; return MUL;
@ -2464,22 +2478,35 @@ static MachineInstr *genMaddR(MachineFunction &MF, MachineRegisterInfo &MRI,
const TargetInstrInfo *TII, MachineInstr &Root, const TargetInstrInfo *TII, MachineInstr &Root,
SmallVectorImpl<MachineInstr *> &InsInstrs, SmallVectorImpl<MachineInstr *> &InsInstrs,
unsigned IdxMulOpd, unsigned MaddOpc, unsigned IdxMulOpd, unsigned MaddOpc,
unsigned VR) { unsigned VR, const TargetRegisterClass *RC) {
assert(IdxMulOpd == 1 || IdxMulOpd == 2); assert(IdxMulOpd == 1 || IdxMulOpd == 2);
MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg()); MachineInstr *MUL = MRI.getUniqueVRegDef(Root.getOperand(IdxMulOpd).getReg());
MachineOperand R = Root.getOperand(0); unsigned ResultReg = Root.getOperand(0).getReg();
MachineOperand A = MUL->getOperand(1); unsigned SrcReg0 = MUL->getOperand(1).getReg();
MachineOperand B = MUL->getOperand(2); bool Src0IsKill = MUL->getOperand(1).isKill();
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc)) unsigned SrcReg1 = MUL->getOperand(2).getReg();
.addOperand(R) bool Src1IsKill = MUL->getOperand(2).isKill();
.addOperand(A)
.addOperand(B) if (TargetRegisterInfo::isVirtualRegister(ResultReg))
MRI.constrainRegClass(ResultReg, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg0))
MRI.constrainRegClass(SrcReg0, RC);
if (TargetRegisterInfo::isVirtualRegister(SrcReg1))
MRI.constrainRegClass(SrcReg1, RC);
if (TargetRegisterInfo::isVirtualRegister(VR))
MRI.constrainRegClass(VR, RC);
MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc),
ResultReg)
.addReg(SrcReg0, getKillRegState(Src0IsKill))
.addReg(SrcReg1, getKillRegState(Src1IsKill))
.addReg(VR); .addReg(VR);
// Insert the MADD // Insert the MADD
InsInstrs.push_back(MIB); InsInstrs.push_back(MIB);
return MUL; return MUL;
} }
/// genAlternativeCodeSequence - when hasPattern() finds a pattern /// genAlternativeCodeSequence - when hasPattern() finds a pattern
/// this function generates the instructions that could replace the /// this function generates the instructions that could replace the
/// original code sequence /// original code sequence
@ -2494,6 +2521,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
const TargetInstrInfo *TII = MF.getTarget().getSubtargetImpl()->getInstrInfo(); const TargetInstrInfo *TII = MF.getTarget().getSubtargetImpl()->getInstrInfo();
MachineInstr *MUL; MachineInstr *MUL;
const TargetRegisterClass *RC = nullptr;
unsigned Opc; unsigned Opc;
switch (Pattern) { switch (Pattern) {
default: default:
@ -2507,7 +2535,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// --- Create(MADD); // --- Create(MADD);
Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP1 ? AArch64::MADDWrrr Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP1 ? AArch64::MADDWrrr
: AArch64::MADDXrrr; : AArch64::MADDXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc); if (Pattern == MachineCombinerPattern::MC_MULADDW_OP1)
RC = &AArch64::GPR32RegClass;
else
RC = &AArch64::GPR64RegClass;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
break; break;
case MachineCombinerPattern::MC_MULADDW_OP2: case MachineCombinerPattern::MC_MULADDW_OP2:
case MachineCombinerPattern::MC_MULADDX_OP2: case MachineCombinerPattern::MC_MULADDX_OP2:
@ -2517,52 +2549,56 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// --- Create(MADD); // --- Create(MADD);
Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP2 ? AArch64::MADDWrrr Opc = Pattern == MachineCombinerPattern::MC_MULADDW_OP2 ? AArch64::MADDWrrr
: AArch64::MADDXrrr; : AArch64::MADDXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); if (Pattern == MachineCombinerPattern::MC_MULADDW_OP2)
RC = &AArch64::GPR32RegClass;
else
RC = &AArch64::GPR64RegClass;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break; break;
case MachineCombinerPattern::MC_MULADDWI_OP1: case MachineCombinerPattern::MC_MULADDWI_OP1:
case MachineCombinerPattern::MC_MULADDXI_OP1: case MachineCombinerPattern::MC_MULADDXI_OP1: {
// MUL I=A,B,0 // MUL I=A,B,0
// ADD R,I,Imm // ADD R,I,Imm
// ==> ORR V, ZR, Imm // ==> ORR V, ZR, Imm
// ==> MADD R,A,B,V // ==> MADD R,A,B,V
// --- Create(MADD); // --- Create(MADD);
{ const TargetRegisterClass *OrrRC = nullptr;
const TargetRegisterClass *RC = unsigned BitSize, OrrOpc, ZeroReg;
MRI.getRegClass(Root.getOperand(1).getReg()); if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) {
unsigned NewVR = MRI.createVirtualRegister(RC); OrrOpc = AArch64::ORRWri;
unsigned BitSize, OrrOpc, ZeroReg; OrrRC = &AArch64::GPR32spRegClass;
if (Pattern == MachineCombinerPattern::MC_MULADDWI_OP1) { BitSize = 32;
BitSize = 32; ZeroReg = AArch64::WZR;
OrrOpc = AArch64::ORRWri; Opc = AArch64::MADDWrrr;
ZeroReg = AArch64::WZR; RC = &AArch64::GPR32RegClass;
Opc = AArch64::MADDWrrr; } else {
} else { OrrOpc = AArch64::ORRXri;
OrrOpc = AArch64::ORRXri; OrrRC = &AArch64::GPR64spRegClass;
BitSize = 64; BitSize = 64;
ZeroReg = AArch64::XZR; ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr; Opc = AArch64::MADDXrrr;
} RC = &AArch64::GPR64RegClass;
uint64_t Imm = Root.getOperand(2).getImm(); }
unsigned NewVR = MRI.createVirtualRegister(OrrRC);
uint64_t Imm = Root.getOperand(2).getImm();
if (Root.getOperand(3).isImm()) { if (Root.getOperand(3).isImm()) {
unsigned val = Root.getOperand(3).getImm(); unsigned Val = Root.getOperand(3).getImm();
Imm = Imm << val; Imm = Imm << Val;
} }
uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize); uint64_t UImm = Imm << (64 - BitSize) >> (64 - BitSize);
uint64_t Encoding; uint64_t Encoding;
if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) {
if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { MachineInstrBuilder MIB1 =
MachineInstrBuilder MIB1 = BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR)
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) .addReg(ZeroReg)
.addOperand(MachineOperand::CreateReg(NewVR, true)) .addImm(Encoding);
.addReg(ZeroReg) InsInstrs.push_back(MIB1);
.addImm(Encoding); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
InsInstrs.push_back(MIB1); MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR);
}
} }
break; break;
}
case MachineCombinerPattern::MC_MULSUBW_OP1: case MachineCombinerPattern::MC_MULSUBW_OP1:
case MachineCombinerPattern::MC_MULSUBX_OP1: { case MachineCombinerPattern::MC_MULSUBX_OP1: {
// MUL I=A,B,0 // MUL I=A,B,0
@ -2570,29 +2606,32 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ==> SUB V, 0, C // ==> SUB V, 0, C
// ==> MADD R,A,B,V // = -C + A*B // ==> MADD R,A,B,V // = -C + A*B
// --- Create(MADD); // --- Create(MADD);
const TargetRegisterClass *RC = const TargetRegisterClass *SubRC = nullptr;
MRI.getRegClass(Root.getOperand(1).getReg());
unsigned NewVR = MRI.createVirtualRegister(RC);
unsigned SubOpc, ZeroReg; unsigned SubOpc, ZeroReg;
if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP1) { if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP1) {
SubOpc = AArch64::SUBWrr; SubOpc = AArch64::SUBWrr;
SubRC = &AArch64::GPR32spRegClass;
ZeroReg = AArch64::WZR; ZeroReg = AArch64::WZR;
Opc = AArch64::MADDWrrr; Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else { } else {
SubOpc = AArch64::SUBXrr; SubOpc = AArch64::SUBXrr;
SubRC = &AArch64::GPR64spRegClass;
ZeroReg = AArch64::XZR; ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr; Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
} }
unsigned NewVR = MRI.createVirtualRegister(SubRC);
// SUB NewVR, 0, C // SUB NewVR, 0, C
MachineInstrBuilder MIB1 = MachineInstrBuilder MIB1 =
BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc)) BuildMI(MF, Root.getDebugLoc(), TII->get(SubOpc), NewVR)
.addOperand(MachineOperand::CreateReg(NewVR, true))
.addReg(ZeroReg) .addReg(ZeroReg)
.addOperand(Root.getOperand(2)); .addOperand(Root.getOperand(2));
InsInstrs.push_back(MIB1); InsInstrs.push_back(MIB1);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
} break; break;
}
case MachineCombinerPattern::MC_MULSUBW_OP2: case MachineCombinerPattern::MC_MULSUBW_OP2:
case MachineCombinerPattern::MC_MULSUBX_OP2: case MachineCombinerPattern::MC_MULSUBX_OP2:
// MUL I=A,B,0 // MUL I=A,B,0
@ -2601,7 +2640,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// --- Create(MSUB); // --- Create(MSUB);
Opc = Pattern == MachineCombinerPattern::MC_MULSUBW_OP2 ? AArch64::MSUBWrrr Opc = Pattern == MachineCombinerPattern::MC_MULSUBW_OP2 ? AArch64::MSUBWrrr
: AArch64::MSUBXrrr; : AArch64::MSUBXrrr;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc); if (Pattern == MachineCombinerPattern::MC_MULSUBW_OP2)
RC = &AArch64::GPR32RegClass;
else
RC = &AArch64::GPR64RegClass;
MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
break; break;
case MachineCombinerPattern::MC_MULSUBWI_OP1: case MachineCombinerPattern::MC_MULSUBWI_OP1:
case MachineCombinerPattern::MC_MULSUBXI_OP1: { case MachineCombinerPattern::MC_MULSUBXI_OP1: {
@ -2610,40 +2653,43 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
// ==> ORR V, ZR, -Imm // ==> ORR V, ZR, -Imm
// ==> MADD R,A,B,V // = -Imm + A*B // ==> MADD R,A,B,V // = -Imm + A*B
// --- Create(MADD); // --- Create(MADD);
const TargetRegisterClass *RC = const TargetRegisterClass *OrrRC = nullptr;
MRI.getRegClass(Root.getOperand(1).getReg());
unsigned NewVR = MRI.createVirtualRegister(RC);
unsigned BitSize, OrrOpc, ZeroReg; unsigned BitSize, OrrOpc, ZeroReg;
if (Pattern == MachineCombinerPattern::MC_MULSUBWI_OP1) { if (Pattern == MachineCombinerPattern::MC_MULSUBWI_OP1) {
BitSize = 32;
OrrOpc = AArch64::ORRWri; OrrOpc = AArch64::ORRWri;
RC = &AArch64::GPR32spRegClass;
BitSize = 32;
ZeroReg = AArch64::WZR; ZeroReg = AArch64::WZR;
Opc = AArch64::MADDWrrr; Opc = AArch64::MADDWrrr;
RC = &AArch64::GPR32RegClass;
} else { } else {
OrrOpc = AArch64::ORRXri; OrrOpc = AArch64::ORRXri;
RC = &AArch64::GPR64RegClass;
BitSize = 64; BitSize = 64;
ZeroReg = AArch64::XZR; ZeroReg = AArch64::XZR;
Opc = AArch64::MADDXrrr; Opc = AArch64::MADDXrrr;
RC = &AArch64::GPR64RegClass;
} }
unsigned NewVR = MRI.createVirtualRegister(OrrRC);
int Imm = Root.getOperand(2).getImm(); int Imm = Root.getOperand(2).getImm();
if (Root.getOperand(3).isImm()) { if (Root.getOperand(3).isImm()) {
unsigned val = Root.getOperand(3).getImm(); unsigned Val = Root.getOperand(3).getImm();
Imm = Imm << val; Imm = Imm << Val;
} }
uint64_t UImm = -Imm << (64 - BitSize) >> (64 - BitSize); uint64_t UImm = -Imm << (64 - BitSize) >> (64 - BitSize);
uint64_t Encoding; uint64_t Encoding;
if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) {
MachineInstrBuilder MIB1 = MachineInstrBuilder MIB1 =
BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc)) BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR)
.addOperand(MachineOperand::CreateReg(NewVR, true))
.addReg(ZeroReg) .addReg(ZeroReg)
.addImm(Encoding); .addImm(Encoding);
InsInstrs.push_back(MIB1); InsInstrs.push_back(MIB1);
InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR); MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC);
} }
} break; break;
} }
} // end switch (Pattern)
// Record MUL and ADD/SUB for deletion // Record MUL and ADD/SUB for deletion
DelInstrs.push_back(MUL); DelInstrs.push_back(MUL);
DelInstrs.push_back(&Root); DelInstrs.push_back(&Root);

View File

@ -0,0 +1,20 @@
; RUN: llc -mtriple=aarch64-apple-darwin -verify-machineinstrs < %s | FileCheck %s
; Test that we use the correct register class.
define i32 @mul_add_imm(i32 %a, i32 %b) {
; CHECK-LABEL: mul_add_imm
; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4
; CHECK-NEXT: madd {{w[0-9]+}}, w0, w1, [[REG]]
%1 = mul i32 %a, %b
%2 = add i32 %1, 4
ret i32 %2
}
define i32 @mul_sub_imm1(i32 %a, i32 %b) {
; CHECK-LABEL: mul_sub_imm1
; CHECK: orr [[REG:w[0-9]+]], wzr, #0x4
; CHECK-NEXT: msub {{w[0-9]+}}, w0, w1, [[REG]]
%1 = mul i32 %a, %b
%2 = sub i32 4, %1
ret i32 %2
}