Unverified Commit 8e77390c authored by Shengchen Kan's avatar Shengchen Kan Committed by GitHub
Browse files

[X86][CodeGen] Support folding memory broadcast in X86InstrInfo::foldMemoryOperandImpl (#79761)

parent c12f30c7
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1067,7 +1067,7 @@ multiclass avx512_broadcast_rm_split<bits<8> opc, string OpcodeStr,
                        MaskInfo.RC:$src0))],
                      DestInfo.ExeDomain>, T8, PD, EVEX, EVEX_K, Sched<[SchedRR]>;
  let hasSideEffects = 0, mayLoad = 1 in
  let hasSideEffects = 0, mayLoad = 1, isReMaterializable = 1, canFoldAsLoad = 1 in
  def rm : AVX512PI<opc, MRMSrcMem, (outs MaskInfo.RC:$dst),
                    (ins SrcInfo.ScalarMemOp:$src),
                    !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
+19 −2
Original line number Diff line number Diff line
@@ -143,6 +143,23 @@ const X86FoldTableEntry *llvm::lookupFoldTable(unsigned RegOp, unsigned OpNum) {
  return lookupFoldTableImpl(FoldTable, RegOp);
}

// Return the broadcast folding table entry for folding RegOp's operand
// OpNum with a broadcast load, or nullptr if no table covers OpNum.
const X86FoldTableEntry *llvm::lookupBroadcastFoldTable(unsigned RegOp,
                                                        unsigned OpNum) {
  // Each memory-operand position 1..4 has its own broadcast table; any
  // other position cannot be folded.
  ArrayRef<X86FoldTableEntry> Table;
  switch (OpNum) {
  case 1:
    Table = ArrayRef(BroadcastTable1);
    break;
  case 2:
    Table = ArrayRef(BroadcastTable2);
    break;
  case 3:
    Table = ArrayRef(BroadcastTable3);
    break;
  case 4:
    Table = ArrayRef(BroadcastTable4);
    break;
  default:
    return nullptr;
  }
  return lookupFoldTableImpl(Table, RegOp);
}

namespace {

// This class stores the memory unfolding tables. It is instantiated as a
@@ -288,7 +305,7 @@ struct X86BroadcastFoldTable {
};
} // namespace

static bool matchBroadcastSize(const X86FoldTableEntry &Entry,
bool llvm::matchBroadcastSize(const X86FoldTableEntry &Entry,
                              unsigned BroadcastBits) {
  switch (Entry.Flags & TB_BCAST_MASK) {
  case TB_BCAST_W:
+6 −0
Original line number Diff line number Diff line
@@ -44,6 +44,11 @@ const X86FoldTableEntry *lookupTwoAddrFoldTable(unsigned RegOp);
// operand OpNum.
const X86FoldTableEntry *lookupFoldTable(unsigned RegOp, unsigned OpNum);

// Look up the broadcast folding table entry for folding a broadcast with
// operand OpNum.
const X86FoldTableEntry *lookupBroadcastFoldTable(unsigned RegOp,
                                                  unsigned OpNum);

// Look up the memory unfolding table entry for this instruction.
const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);

@@ -52,6 +57,7 @@ const X86FoldTableEntry *lookupUnfoldTable(unsigned MemOp);
const X86FoldTableEntry *lookupBroadcastFoldTableBySize(unsigned MemOp,
                                                        unsigned BroadcastBits);

bool matchBroadcastSize(const X86FoldTableEntry &Entry, unsigned BroadcastBits);
} // namespace llvm

#endif
+86 −0
Original line number Diff line number Diff line
@@ -862,6 +862,28 @@ bool X86InstrInfo::isReallyTriviallyReMaterializable(
  case X86::MMX_MOVD64rm:
  case X86::MMX_MOVQ64rm:
  // AVX-512
  case X86::VPBROADCASTBZ128rm:
  case X86::VPBROADCASTBZ256rm:
  case X86::VPBROADCASTBZrm:
  case X86::VBROADCASTF32X2Z256rm:
  case X86::VBROADCASTF32X2Zrm:
  case X86::VBROADCASTI32X2Z128rm:
  case X86::VBROADCASTI32X2Z256rm:
  case X86::VBROADCASTI32X2Zrm:
  case X86::VPBROADCASTWZ128rm:
  case X86::VPBROADCASTWZ256rm:
  case X86::VPBROADCASTWZrm:
  case X86::VPBROADCASTDZ128rm:
  case X86::VPBROADCASTDZ256rm:
  case X86::VPBROADCASTDZrm:
  case X86::VBROADCASTSSZ128rm:
  case X86::VBROADCASTSSZ256rm:
  case X86::VBROADCASTSSZrm:
  case X86::VPBROADCASTQZ128rm:
  case X86::VPBROADCASTQZ256rm:
  case X86::VPBROADCASTQZrm:
  case X86::VBROADCASTSDZ256rm:
  case X86::VBROADCASTSDZrm:
  case X86::VMOVSSZrm:
  case X86::VMOVSSZrm_alt:
  case X86::VMOVSDZrm:
@@ -8067,6 +8089,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
    MOs.push_back(MachineOperand::CreateReg(0, false));
    break;
  }
  case X86::VPBROADCASTBZ128rm:
  case X86::VPBROADCASTBZ256rm:
  case X86::VPBROADCASTBZrm:
  case X86::VBROADCASTF32X2Z256rm:
  case X86::VBROADCASTF32X2Zrm:
  case X86::VBROADCASTI32X2Z128rm:
  case X86::VBROADCASTI32X2Z256rm:
  case X86::VBROADCASTI32X2Zrm:
    // No instructions currently fuse with 8bits or 32bits x 2.
    return nullptr;

#define FOLD_BROADCAST(SIZE)                                                   \
  MOs.append(LoadMI.operands_begin() + NumOps - X86::AddrNumOperands,          \
             LoadMI.operands_begin() + NumOps);                                \
  return foldMemoryBroadcast(MF, MI, Ops[0], MOs, InsertPt, /*Size=*/SIZE,     \
                             /*AllowCommute=*/true);
  case X86::VPBROADCASTWZ128rm:
  case X86::VPBROADCASTWZ256rm:
  case X86::VPBROADCASTWZrm:
    FOLD_BROADCAST(16);
  case X86::VPBROADCASTDZ128rm:
  case X86::VPBROADCASTDZ256rm:
  case X86::VPBROADCASTDZrm:
  case X86::VBROADCASTSSZ128rm:
  case X86::VBROADCASTSSZ256rm:
  case X86::VBROADCASTSSZrm:
    FOLD_BROADCAST(32);
  case X86::VPBROADCASTQZ128rm:
  case X86::VPBROADCASTQZ256rm:
  case X86::VPBROADCASTQZrm:
  case X86::VBROADCASTSDZ256rm:
  case X86::VBROADCASTSDZrm:
    FOLD_BROADCAST(64);
  default: {
    if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
      return nullptr;
@@ -8081,6 +8136,37 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
                               /*Size=*/0, Alignment, /*AllowCommute=*/true);
}

// Try to fold a broadcast load of BitsSize bits into operand OpNum of MI,
// using the broadcast folding tables. Returns the fused instruction, or
// nullptr if no fold is possible. When AllowCommute is set and the direct
// lookup misses, the operands are commuted and the fold is retried once.
MachineInstr *
X86InstrInfo::foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
                                  unsigned OpNum, ArrayRef<MachineOperand> MOs,
                                  MachineBasicBlock::iterator InsertPt,
                                  unsigned BitsSize, bool AllowCommute) const {

  // Direct table hit: fuse only if the entry's broadcast width matches the
  // requested element size; a width mismatch is a hard failure (no commute).
  if (const X86FoldTableEntry *Entry =
          lookupBroadcastFoldTable(MI.getOpcode(), OpNum)) {
    if (!matchBroadcastSize(*Entry, BitsSize))
      return nullptr;
    return FuseInst(MF, Entry->DstOp, OpNum, MOs, InsertPt, MI, *this);
  }

  // No direct entry: commute the target operand into a foldable position
  // and retry once, with recursion disabled to avoid commuting back.
  if (AllowCommute) {
    unsigned OtherIdx = commuteOperandsForFold(MI, OpNum);
    if (OtherIdx == OpNum) {
      // Operand is not commutable; nothing more to try.
      printFailMsgforFold(MI, OpNum);
      return nullptr;
    }
    if (MachineInstr *Folded =
            foldMemoryBroadcast(MF, MI, OtherIdx, MOs, InsertPt, BitsSize,
                                /*AllowCommute=*/false))
      return Folded;
    // Retry failed: restore the original operand order before giving up.
    UndoCommuteForFold(MI, OpNum, OtherIdx);
  }

  printFailMsgforFold(MI, OpNum);
  return nullptr;
}

static SmallVector<MachineMemOperand *, 2>
extractLoadMMOs(ArrayRef<MachineMemOperand *> MMOs, MachineFunction &MF) {
  SmallVector<MachineMemOperand *, 2> LoadMMOs;
+6 −0
Original line number Diff line number Diff line
@@ -643,6 +643,12 @@ private:
                                        MachineBasicBlock::iterator InsertPt,
                                        unsigned Size, Align Alignment) const;

  MachineInstr *foldMemoryBroadcast(MachineFunction &MF, MachineInstr &MI,
                                    unsigned OpNum,
                                    ArrayRef<MachineOperand> MOs,
                                    MachineBasicBlock::iterator InsertPt,
                                    unsigned BitsSize, bool AllowCommute) const;

  /// isFrameOperand - Return true and the FrameIndex if the specified
  /// operand and follow operands form a reference to the stack frame.
  bool isFrameOperand(const MachineInstr &MI, unsigned int Op,
Loading